Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -14,7 +14,7 @@ namespace ir{
// attributes
type *type::get_scalar_ty() const {
if(is_tile_ty())
if(is_block_ty())
return get_tile_element_ty();
return const_cast<type*>(this);
}
@@ -28,7 +28,7 @@ unsigned type::get_primitive_size_in_bits() const {
case FP128TyID: return 128;
case PPC_FP128TyID: return 128;
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
case TileTyID: return ((tile_type*)(this))->get_bitwidth();
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
default: return 0;
}
}
@@ -37,19 +37,19 @@ unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
unsigned type::get_tile_bitwidth() const
{ return ((tile_type*)(this))->get_bitwidth(); }
{ return ((block_type*)(this))->get_bitwidth(); }
unsigned type::get_fp_mantissa_width() const {
id_t id = get_scalar_ty()->id_;
assert(is_floating_point_ty() && "Not a floating point type!");
if (id == HalfTyID) return 11;
if (id == FloatTyID) return 24;
if (id == HalfTyID) return 10;
if (id == FloatTyID) return 23;
if (id == DoubleTyID) return 53;
throw std::runtime_error("unreachable");
}
type* type::get_tile_element_ty() const {
assert(is_tile_ty());
assert(is_block_ty());
return contained_tys_[0];
}
@@ -62,31 +62,31 @@ type * type::get_pointer_element_ty() const {
type *ptr_ty = get_scalar_ty();
assert(ptr_ty->is_pointer_ty());
type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
if(is_tile_ty())
return tile_type::get_same_shapes(scalar_ty, (type*)this);
if(is_block_ty())
return block_type::get_same_shapes(scalar_ty, (type*)this);
return scalar_ty;
}
const type::tile_shapes_t &type::get_tile_shapes() const {
assert(is_tile_ty());
return ((tile_type*)this)->get_shapes();
type::block_shapes_t type::get_block_shapes() const {
assert(is_block_ty());
return ((block_type*)this)->get_shapes();
}
const size_t type::get_tile_rank() const {
return get_tile_shapes().size();
return get_block_shapes().size();
}
const size_t type::get_tile_ranks1() const {
int ret = 0;
for(int s: get_tile_shapes())
for(int s: get_block_shapes())
ret += s > 1;
return ret;
}
unsigned type::get_tile_num_elements() const {
const tile_shapes_t& shapes = get_tile_shapes();
const block_shapes_t& shapes = get_block_shapes();
unsigned result = 1;
for(auto shape: shapes)
result *= shape;
@@ -112,7 +112,7 @@ bool type::is_sized() const {
return true;
}
// tile types are sizes
if(is_tile_ty())
if(is_block_ty())
return get_scalar_ty()->is_sized();
return false;
}
@@ -160,12 +160,12 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
//===----------------------------------------------------------------------===//
type* composite_type::get_type_at_index(value *) const{
assert(is_tile_ty());
assert(is_block_ty());
return get_scalar_ty();
}
bool composite_type::index_valid(value *idx) const{
assert(is_tile_ty());
assert(is_block_ty());
return idx->get_type()->is_int_or_tileint_ty();
}
@@ -173,41 +173,41 @@ bool composite_type::index_valid(value *idx) const{
// tile_type class
//===----------------------------------------------------------------------===//
tile_type::tile_type(type *ty, const tile_shapes_t &shapes)
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
block_type::block_type(type *ty, const block_shapes_t &shapes)
: composite_type(ty->get_context(), BlockTyID), shapes_(shapes) {
contained_tys_.push_back(ty);
}
bool tile_type::is_valid_elt_ty(type *ty) {
bool block_type::is_valid_elt_ty(type *ty) {
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
}
unsigned tile_type::get_num_elements() const {
unsigned block_type::get_num_elements() const {
unsigned res = 1;
for(auto shape: shapes_)
res *= shape;
return res;
}
unsigned tile_type::get_bitwidth() const {
unsigned block_type::get_bitwidth() const {
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
}
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
assert(elt_ty && "Can't get a tile of <null> type!");
assert(shapes.size() && "Can't create a tile with empty shapes!");
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
// look-up
context_impl *impl = elt_ty->get_context().p_impl.get();
tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];
block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
if(!entry)
entry = new tile_type(elt_ty, shapes);
entry = new block_type(elt_ty, shapes);
return entry;
}
tile_type* tile_type::get_same_shapes(type *ty, type *ref){
assert(ref->is_tile_ty());
return get(ty, ref->get_tile_shapes());
block_type* block_type::get_same_shapes(type *ty, type *ref){
assert(ref->is_block_ty());
return get(ty, ref->get_block_shapes());
}
//===----------------------------------------------------------------------===//