[general] major overhaul of triton-c/triton-ir/triton-jit:
- Added alloc const - Added atomics - Pruning tuning space - Added example for dot/conv/shift - Bugfixes
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/allocation.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
@@ -309,7 +309,47 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
}
|
||||
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
|
||||
Value *ptr = value(ii->get_pointer_operand());
|
||||
return builder.Insert(new LoadInst(ptr));
|
||||
LoadInst *result = new LoadInst(ptr);
|
||||
return builder.Insert(result);
|
||||
}
|
||||
if(ir::store_inst* ii = dynamic_cast<ir::store_inst*>(inst)){
|
||||
Value *val = value(ii->get_value_operand());
|
||||
Value *ptr = value(ii->get_pointer_operand());
|
||||
builder.CreateStore(val, ptr);
|
||||
return nullptr;
|
||||
}
|
||||
if(ir::select_inst* ii = dynamic_cast<ir::select_inst*>(inst)){
|
||||
Value *pred = value(ii->get_operand(0));
|
||||
Value *if_value = value(ii->get_operand(1));
|
||||
Value *else_value = value(ii->get_operand(2));
|
||||
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
|
||||
}
|
||||
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
|
||||
Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
|
||||
return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
|
||||
}
|
||||
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
|
||||
BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
|
||||
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_bb);
|
||||
Value *cas_ptr = value(ii->get_operand(0));
|
||||
Value *cas_cmp = value(ii->get_operand(1));
|
||||
Value *cas_val = value(ii->get_operand(2));
|
||||
Value *old = builder.CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
|
||||
old = builder.CreateExtractValue(old, {0});
|
||||
builder.CreateStore(old, ptr);
|
||||
builder.CreateBr(tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_barrier(module, builder);
|
||||
Value *res = builder.CreateLoad(ptr);
|
||||
return (Instruction*)res;
|
||||
}
|
||||
// unknown instruction
|
||||
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
|
||||
@@ -446,7 +486,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
|
||||
if(buffer_info_->is_shared(v))
|
||||
return;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() == 1)
|
||||
@@ -490,20 +530,11 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
shapes2.push_back(shape->get_value());
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
||||
// create shared tile
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
|
||||
if(buffer_info_->is_shared(v)){
|
||||
// shared copy
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
|
||||
// TODO - buffer info not up-to-date with references
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(v)) {
|
||||
if(!has_phi_user(v)){
|
||||
size_t offset = alloc_->get_offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
|
||||
}
|
||||
}
|
||||
// phi-node (double-buffering)
|
||||
else if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
|
||||
unsigned id_pre = 0, id_loop = 1;
|
||||
if(phi->get_incoming_block(0) == phi->get_parent())
|
||||
@@ -522,13 +553,19 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(i);
|
||||
ir::value* inc_value = phi->get_incoming_value(i);
|
||||
ir::value* terminator = inc_block->get_inst_list().back();
|
||||
ir::instruction* terminator = inc_block->get_inst_list().back();
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
|
||||
}
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("unknown shared memory tile");
|
||||
else {
|
||||
if(!has_phi_user(v)){
|
||||
size_t offset = alloc_->get_offset(v);
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
|
||||
}
|
||||
}
|
||||
}
|
||||
// create distributed tile
|
||||
else {
|
||||
@@ -607,10 +644,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
set_mask_insert_pt(idx);
|
||||
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||
StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx));
|
||||
// store->setAlignment(16);
|
||||
builder.Insert(store);
|
||||
});
|
||||
}
|
||||
else {
|
||||
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
|
||||
vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)});
|
||||
return;
|
||||
}
|
||||
tile *ti = tmap_[ins];
|
||||
distributed_tile* result = (distributed_tile*)ti;
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
@@ -727,31 +770,67 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
ti->set_value(idx, in->get_value(idx));
|
||||
});
|
||||
}
|
||||
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins) || (buffer_info_->is_double(ins)))
|
||||
// trans
|
||||
else if(dynamic_cast<ir::trans_inst*>(ins)) {
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
in->for_each([&](indices_t idx){
|
||||
indices_t out_idx = idx;
|
||||
std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
|
||||
ti->set_value(out_idx, in->get_value(idx));
|
||||
});
|
||||
}
|
||||
else if(buffer_info_->is_shared(ins))
|
||||
return;
|
||||
// matrix multiplication
|
||||
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
|
||||
// dot
|
||||
else if(auto dot = dynamic_cast<ir::dot_inst*>(ins)) {
|
||||
ir::value *A = ins->get_operand(0);
|
||||
ir::value *B = ins->get_operand(1);
|
||||
ir::value *C = ins->get_operand(2);
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
bool AT = dot->is_a_trans();
|
||||
bool BT = dot->is_b_trans();
|
||||
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
|
||||
for(unsigned K = 0; K < NK; ++K){
|
||||
indices_t a_idx = {idx[0], builder.getInt32(K)};
|
||||
indices_t b_idx = {builder.getInt32(K), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
|
||||
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
|
||||
result->for_each([&](indices_t idx){
|
||||
Value *res = TC->get_value(idx);
|
||||
indices_t a_idx = {idx[0], builder.getInt32(0)};
|
||||
indices_t b_idx = {builder.getInt32(0), idx[1]};
|
||||
if(AT)
|
||||
std::swap(a_idx[0], a_idx[1]);
|
||||
if(BT)
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
}
|
||||
// element-wise
|
||||
else {
|
||||
@@ -858,6 +937,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
sh_mem_ptr_ = sh_mem_ptr;
|
||||
|
||||
// create grids
|
||||
init_grids(fn, dst_builder, sh_mem_ptr);
|
||||
@@ -890,7 +970,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
ir::basic_block* inc_block = phi->get_incoming_block(n);
|
||||
ir::value* inc_val = phi->get_incoming_value(n);
|
||||
ir::value* terminator = inc_block->get_inst_list().back();
|
||||
ir::instruction* terminator = inc_block->get_inst_list().back();
|
||||
BasicBlock *llvm_inc_block = last_block.at(inc_block);
|
||||
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
|
||||
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
|
||||
@@ -920,8 +1000,8 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
});
|
||||
}
|
||||
else {
|
||||
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
|
||||
Value *llvm_inc_val = vmap_.at(inc_val);
|
||||
PHINode *llvm_phi = (PHINode*)llvm_value(phi, dst_builder);
|
||||
Value *llvm_inc_val = llvm_value(inc_val, dst_builder);
|
||||
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user