Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See documentations for more information on the new API
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -14,7 +14,7 @@ namespace ir{
|
||||
|
||||
// attributes
|
||||
type *type::get_scalar_ty() const {
|
||||
if(is_tile_ty())
|
||||
if(is_block_ty())
|
||||
return get_tile_element_ty();
|
||||
return const_cast<type*>(this);
|
||||
}
|
||||
@@ -28,7 +28,7 @@ unsigned type::get_primitive_size_in_bits() const {
|
||||
case FP128TyID: return 128;
|
||||
case PPC_FP128TyID: return 128;
|
||||
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
|
||||
case TileTyID: return ((tile_type*)(this))->get_bitwidth();
|
||||
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@@ -37,19 +37,19 @@ unsigned type::get_integer_bitwidth() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_tile_bitwidth() const
|
||||
{ return ((tile_type*)(this))->get_bitwidth(); }
|
||||
{ return ((block_type*)(this))->get_bitwidth(); }
|
||||
|
||||
unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == HalfTyID) return 11;
|
||||
if (id == FloatTyID) return 24;
|
||||
if (id == HalfTyID) return 10;
|
||||
if (id == FloatTyID) return 23;
|
||||
if (id == DoubleTyID) return 53;
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
type* type::get_tile_element_ty() const {
|
||||
assert(is_tile_ty());
|
||||
assert(is_block_ty());
|
||||
return contained_tys_[0];
|
||||
}
|
||||
|
||||
@@ -62,31 +62,31 @@ type * type::get_pointer_element_ty() const {
|
||||
type *ptr_ty = get_scalar_ty();
|
||||
assert(ptr_ty->is_pointer_ty());
|
||||
type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
|
||||
if(is_tile_ty())
|
||||
return tile_type::get_same_shapes(scalar_ty, (type*)this);
|
||||
if(is_block_ty())
|
||||
return block_type::get_same_shapes(scalar_ty, (type*)this);
|
||||
return scalar_ty;
|
||||
}
|
||||
|
||||
|
||||
const type::tile_shapes_t &type::get_tile_shapes() const {
|
||||
assert(is_tile_ty());
|
||||
return ((tile_type*)this)->get_shapes();
|
||||
type::block_shapes_t type::get_block_shapes() const {
|
||||
assert(is_block_ty());
|
||||
return ((block_type*)this)->get_shapes();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_rank() const {
|
||||
return get_tile_shapes().size();
|
||||
return get_block_shapes().size();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_ranks1() const {
|
||||
int ret = 0;
|
||||
for(int s: get_tile_shapes())
|
||||
for(int s: get_block_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
unsigned type::get_tile_num_elements() const {
|
||||
const tile_shapes_t& shapes = get_tile_shapes();
|
||||
const block_shapes_t& shapes = get_block_shapes();
|
||||
unsigned result = 1;
|
||||
for(auto shape: shapes)
|
||||
result *= shape;
|
||||
@@ -112,7 +112,7 @@ bool type::is_sized() const {
|
||||
return true;
|
||||
}
|
||||
// tile types are sizes
|
||||
if(is_tile_ty())
|
||||
if(is_block_ty())
|
||||
return get_scalar_ty()->is_sized();
|
||||
return false;
|
||||
}
|
||||
@@ -160,12 +160,12 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
type* composite_type::get_type_at_index(value *) const{
|
||||
assert(is_tile_ty());
|
||||
assert(is_block_ty());
|
||||
return get_scalar_ty();
|
||||
}
|
||||
|
||||
bool composite_type::index_valid(value *idx) const{
|
||||
assert(is_tile_ty());
|
||||
assert(is_block_ty());
|
||||
return idx->get_type()->is_int_or_tileint_ty();
|
||||
}
|
||||
|
||||
@@ -173,41 +173,41 @@ bool composite_type::index_valid(value *idx) const{
|
||||
// tile_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
tile_type::tile_type(type *ty, const tile_shapes_t &shapes)
|
||||
: composite_type(ty->get_context(), TileTyID), shapes_(shapes) {
|
||||
block_type::block_type(type *ty, const block_shapes_t &shapes)
|
||||
: composite_type(ty->get_context(), BlockTyID), shapes_(shapes) {
|
||||
contained_tys_.push_back(ty);
|
||||
}
|
||||
|
||||
bool tile_type::is_valid_elt_ty(type *ty) {
|
||||
bool block_type::is_valid_elt_ty(type *ty) {
|
||||
return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
|
||||
}
|
||||
|
||||
unsigned tile_type::get_num_elements() const {
|
||||
unsigned block_type::get_num_elements() const {
|
||||
unsigned res = 1;
|
||||
for(auto shape: shapes_)
|
||||
res *= shape;
|
||||
return res;
|
||||
}
|
||||
|
||||
unsigned tile_type::get_bitwidth() const {
|
||||
unsigned block_type::get_bitwidth() const {
|
||||
return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
|
||||
}
|
||||
|
||||
tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) {
|
||||
block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
|
||||
assert(elt_ty && "Can't get a tile of <null> type!");
|
||||
assert(shapes.size() && "Can't create a tile with empty shapes!");
|
||||
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
|
||||
// look-up
|
||||
context_impl *impl = elt_ty->get_context().p_impl.get();
|
||||
tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)];
|
||||
block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
|
||||
if(!entry)
|
||||
entry = new tile_type(elt_ty, shapes);
|
||||
entry = new block_type(elt_ty, shapes);
|
||||
return entry;
|
||||
}
|
||||
|
||||
tile_type* tile_type::get_same_shapes(type *ty, type *ref){
|
||||
assert(ref->is_tile_ty());
|
||||
return get(ty, ref->get_tile_shapes());
|
||||
block_type* block_type::get_same_shapes(type *ty, type *ref){
|
||||
assert(ref->is_block_ty());
|
||||
return get(ty, ref->get_block_shapes());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Reference in New Issue
Block a user