Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the Triton frontend and replaces Triton-C with a pure Python API in which kernels are defined as @triton.jit-decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See the documentation for more information on the new API.
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -150,7 +150,7 @@ generator::generator(analysis::axes *a_axes,
|
||||
void generator::visit_value(ir::value* v) {
|
||||
if(!seen_.insert(v).second)
|
||||
return;
|
||||
if(v->get_type()->is_tile_ty()){
|
||||
if(v->get_type()->is_block_ty()){
|
||||
if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
|
||||
auto double_buffer = layout->get_double_buffer();
|
||||
// offset
|
||||
@@ -384,7 +384,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
if(op->get_type()->is_tile_ty()){
|
||||
if(op->get_type()->is_block_ty()){
|
||||
auto ord = ords_.at(op);
|
||||
size_t aln = alignment_->get(op, ord[0]);
|
||||
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
|
||||
@@ -407,7 +407,10 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2);
|
||||
Instruction *then_term;
|
||||
Instruction *else_term;
|
||||
builder_->SetInsertPoint(_ret->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(then_term);
|
||||
Value* then_ret = load(ptr);
|
||||
builder_->SetInsertPoint(else_term);
|
||||
@@ -441,7 +444,7 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
ir::value *val_op = x->get_value_operand();
|
||||
// vector size
|
||||
size_t vec = 1;
|
||||
if(val_op->get_type()->is_tile_ty()){
|
||||
if(val_op->get_type()->is_block_ty()){
|
||||
auto ord = ords_.at(x->get_pointer_operand());
|
||||
size_t aln = alignment_->get(ptr_op, ord[0]);
|
||||
size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
|
||||
@@ -461,7 +464,10 @@ void generator::visit_store_inst(ir::store_inst * x){
|
||||
if(mx){
|
||||
Value *msk = vals_[mx->get_mask_operand()][idx];
|
||||
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
|
||||
builder_->SetInsertPoint(no_op->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
store(val, ptr);
|
||||
builder_->SetInsertPoint(no_op);
|
||||
@@ -501,13 +507,15 @@ void generator::visit_splat_inst(ir::splat_inst* x) {
|
||||
*/
|
||||
void generator::visit_broadcast_inst(ir::broadcast_inst* x) {
|
||||
ir::value* op = x->get_operand(0);
|
||||
const auto& shape = op->get_type()->get_tile_shapes();
|
||||
const auto& shape = op->get_type()->get_block_shapes();
|
||||
for(auto out_idx: idxs_.at(x)){
|
||||
indices_t in_idx = out_idx;
|
||||
for(size_t k = 0; k < in_idx.size(); k++)
|
||||
in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k];
|
||||
vals_[x][out_idx] = vals_[op][in_idx];
|
||||
}
|
||||
// for(size_t i = 0; i < idxs_.at(x).size(); i++)
|
||||
// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -527,9 +535,9 @@ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `get_num_program`
|
||||
* \brief Code Generation for `get_num_programs`
|
||||
*/
|
||||
void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) {
|
||||
void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) {
|
||||
Module *module = builder_->GetInsertBlock()->getModule();
|
||||
Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis());
|
||||
vals_[np][{}] = ret;
|
||||
@@ -621,7 +629,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
//TODO: clean-up
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
|
||||
if(add->get_type()->is_tile_ty()){
|
||||
if(add->get_type()->is_block_ty()){
|
||||
ir::value* ptr = add->get_operand(0);
|
||||
ir::value* val = add->get_operand(1);
|
||||
ir::value* msk = add->get_operand(2);
|
||||
@@ -706,9 +714,9 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
//TODO: clean-up
|
||||
void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
|
||||
// shapes
|
||||
auto shape_c = C->get_type()->get_tile_shapes();
|
||||
auto shape_a = A->get_type()->get_tile_shapes();
|
||||
auto shape_b = B->get_type()->get_tile_shapes();
|
||||
auto shape_c = C->get_type()->get_block_shapes();
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
// order
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
@@ -877,7 +885,7 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
*/
|
||||
//TODO: clean-up
|
||||
void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
|
||||
const auto& shapes = dot->get_type()->get_tile_shapes();
|
||||
const auto& shapes = dot->get_type()->get_block_shapes();
|
||||
|
||||
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
|
||||
|
||||
@@ -887,8 +895,8 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
|
||||
fcs[key].push_back(vals_[D][idx]);
|
||||
};
|
||||
|
||||
auto shape_a = A->get_type()->get_tile_shapes();
|
||||
auto shape_b = B->get_type()->get_tile_shapes();
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
|
||||
@@ -1059,9 +1067,9 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
|
||||
* \brief Code Generation for FMA-based `dot` (FP32, FP64, Default)
|
||||
*/
|
||||
void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) {
|
||||
auto shape_c = C->get_type()->get_tile_shapes();
|
||||
auto shape_a = A->get_type()->get_tile_shapes();
|
||||
auto shape_b = B->get_type()->get_tile_shapes();
|
||||
auto shape_c = C->get_type()->get_block_shapes();
|
||||
auto shape_a = A->get_type()->get_block_shapes();
|
||||
auto shape_b = B->get_type()->get_block_shapes();
|
||||
auto ord_a = layouts_->get(A)->get_order();
|
||||
auto ord_b = layouts_->get(B)->get_order();
|
||||
analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline();
|
||||
@@ -1161,7 +1169,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
ir::value *D = dot->get_operand(2);
|
||||
Type *c_ty = cvt(D->get_type()->get_scalar_ty());
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
auto A_shapes = A->get_type()->get_block_shapes();
|
||||
size_t red_axis = 1;
|
||||
unsigned NK = A_shapes[red_axis];
|
||||
bool is_outer = NK == 1;
|
||||
@@ -1236,7 +1244,10 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
// reduce across warps
|
||||
Value *cond = icmp_eq(warp, i32(0));
|
||||
Instruction *barrier = add_barrier();
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
Value* ret = load(gep(base, thread));
|
||||
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
|
||||
@@ -1359,10 +1370,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
* \brief Code Generation for `select`
|
||||
*/
|
||||
void generator::visit_select_inst(ir::select_inst* x) {
|
||||
for(indices_t idx: idxs_.at(x))
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
|
||||
vals_[x->get_operand(1)][idx],
|
||||
vals_[x->get_operand(2)][idx]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1370,7 +1382,7 @@ void generator::visit_select_inst(ir::select_inst* x) {
|
||||
*/
|
||||
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
ir::value *op = rc->get_operand(0);
|
||||
ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes();
|
||||
ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
|
||||
// pointer to temporary shared memory
|
||||
Type *ty = cvt(rc->get_type()->get_scalar_ty());
|
||||
// layout
|
||||
@@ -1435,7 +1447,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
|
||||
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
|
||||
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
auto shapes = x->get_type()->get_block_shapes();
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
std::map<std::pair<int, int>, Value*> tmp;
|
||||
@@ -1520,7 +1532,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
auto shapes = cts->get_type()->get_tile_shapes();
|
||||
auto shapes = cts->get_type()->get_block_shapes();
|
||||
|
||||
// store to shared
|
||||
Value *current = nullptr;
|
||||
@@ -1901,13 +1913,13 @@ void generator::visit_argument(ir::argument* arg) {
|
||||
|
||||
void generator::init_idx(ir::value *v) {
|
||||
idxs_[v].clear();
|
||||
if(!v->get_type()->is_tile_ty()){
|
||||
if(!v->get_type()->is_block_ty()){
|
||||
idxs_[v].push_back({});
|
||||
return;
|
||||
}
|
||||
if(layouts_->get(v)->to_shared())
|
||||
return;
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
const auto &shapes = v->get_type()->get_block_shapes();
|
||||
size_t rank = shapes.size();
|
||||
std::vector<distributed_axis> axes(rank);
|
||||
std::vector<int> ord(rank);
|
||||
|
Reference in New Issue
Block a user