Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -325,9 +325,9 @@ layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps,
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
if(!x->get_type()->is_block_ty())
return;
if(!y->get_type()->is_tile_ty())
if(!y->get_type()->is_block_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
@@ -364,7 +364,7 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
const auto& shapes = largest->get_type()->get_block_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
@@ -411,7 +411,7 @@ void layouts::run(ir::module &mod) {
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_tile_shapes();
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
// create layout
@@ -425,8 +425,8 @@ void layouts::run(ir::module &mod) {
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes();
ir::type::block_shapes_t shape(in_shape.size());
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)