[general] major overhaul of triton-c/triton-ir/triton-jit:

- Added alloc const
- Added atomics
- Pruned the tuning space
- Added examples for dot/conv/shift
- Bugfixes
Philippe Tillet
2019-04-25 16:17:36 -04:00
parent 0c607c9392
commit 3413aad582
50 changed files with 2051 additions and 570 deletions


@@ -1,6 +1,6 @@
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/target.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
@@ -309,7 +309,47 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand());
return builder.Insert(new LoadInst(ptr));
LoadInst *result = new LoadInst(ptr);
return builder.Insert(result);
}
if(ir::store_inst* ii = dynamic_cast<ir::store_inst*>(inst)){
Value *val = value(ii->get_value_operand());
Value *ptr = value(ii->get_pointer_operand());
builder.CreateStore(val, ptr);
return nullptr;
}
if(ir::select_inst* ii = dynamic_cast<ir::select_inst*>(inst)){
Value *pred = value(ii->get_operand(0));
Value *if_value = value(ii->get_operand(1));
Value *else_value = value(ii->get_operand(2));
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
}
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
BasicBlock *current = builder.GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, builder, 0);
Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder.SetInsertPoint(tid_0_bb);
Value *cas_ptr = value(ii->get_operand(0));
Value *cas_cmp = value(ii->get_operand(1));
Value *cas_val = value(ii->get_operand(2));
Value *old = builder.CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
old = builder.CreateExtractValue(old, {0});
builder.CreateStore(old, ptr);
builder.CreateBr(tid_0_done_bb);
builder.SetInsertPoint(tid_0_done_bb);
tgt_->add_barrier(module, builder);
Value *res = builder.CreateLoad(ptr);
return (Instruction*)res;
}
// unknown instruction
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
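
For context, the atomic_cas lowering above makes only the thread with local id 0 perform the compare-and-swap; it parks the returned old value in a scratch slot of shared memory, and after a barrier every thread of the block loads that slot, so all threads observe the same result. The following is a minimal CPU-side sketch of that leader-and-broadcast pattern using C++20 threads (thread count, values, and variable names are made up for illustration; it is not the generated GPU code):

    #include <atomic>
    #include <barrier>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
      constexpr int num_threads = 4;          // stand-in for one thread block
      std::atomic<int> global{7};             // stand-in for the global pointer operand
      int staging = 0;                        // stand-in for the shared-memory scratch slot
      std::barrier sync_point(num_threads);   // stand-in for tgt_->add_barrier(...)

      auto worker = [&](int tid) {
        if (tid == 0) {                       // only the "tid == 0" leader issues the CAS
          int expected = 7;                                // plays the role of cas_cmp
          global.compare_exchange_strong(expected, 42);    // 42 plays the role of cas_val
          staging = expected;                 // 'expected' now holds the old value; publish it
        }
        sync_point.arrive_and_wait();         // every thread waits for the leader's store
        std::printf("thread %d sees old value %d\n", tid, staging);
      };

      std::vector<std::thread> pool;
      for (int t = 0; t < num_threads; ++t)
        pool.emplace_back(worker, t);
      for (auto &t : pool)
        t.join();
      return 0;
    }
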
@@ -446,7 +486,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
bind_references(op);
// bind
const auto& shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
if(buffer_info_->is_shared(v))
return;
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
@@ -490,20 +530,11 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
shapes2.push_back(shape->get_value());
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
// create shared tile
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
if(buffer_info_->is_shared(v)){
// shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
// TODO - buffer info not up-to-date with references
if(dynamic_cast<ir::copy_to_shared_inst*>(v)) {
if(!has_phi_user(v)){
size_t offset = alloc_->get_offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
}
}
// phi-node (double-buffering)
else if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()];
unsigned id_pre = 0, id_loop = 1;
if(phi->get_incoming_block(0) == phi->get_parent())
@@ -522,13 +553,19 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
for(unsigned i = 0; i < phi->get_num_incoming(); i++) {
ir::basic_block* inc_block = phi->get_incoming_block(i);
ir::value* inc_value = phi->get_incoming_value(i);
ir::value* terminator = inc_block->get_inst_list().back();
ir::instruction* terminator = inc_block->get_inst_list().back();
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)});
}
}
else
throw std::runtime_error("unknown shared memory tile");
else {
if(!has_phi_user(v)){
size_t offset = alloc_->get_offset(v);
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
ptr = builder.CreateBitCast(ptr, ptr_ty);
tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)});
}
}
}
// create distributed tile
else {
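
For context, the phi branch above implements double buffering of shared tiles: a tile that is live across loop iterations gets two slots, the value coming from the pre-header is written through pre_ptr, the value coming from the loop latch through next_ptr, and the two pointers trade roles every iteration so one slot can be refilled while the other is read. A small host-side sketch of the same idea, with made-up tile sizes and a trivial produce/consume step standing in for the real loads and dots:

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      constexpr std::size_t tile_size = 8;    // made-up tile size
      std::vector<float> slot_a(tile_size), slot_b(tile_size);
      float *read_buf  = slot_a.data();       // corresponds to pre_ptr, filled before the loop
      float *write_buf = slot_b.data();       // corresponds to next_ptr, filled by the loop body

      // Prologue: stage the first tile into the read buffer.
      std::fill(read_buf, read_buf + tile_size, 1.0f);

      for (int iter = 0; iter < 4; ++iter) {
        // Prefetch the next tile into the write buffer...
        std::fill(write_buf, write_buf + tile_size, float(iter + 2));
        // ...while the current iteration consumes the read buffer.
        float acc = 0.0f;
        for (std::size_t i = 0; i < tile_size; ++i)
          acc += read_buf[i];
        std::printf("iter %d consumed a tile summing to %g\n", iter, acc);
        // Loop latch: swap roles, which is what the phi over {pre_ptr, next_ptr} expresses.
        std::swap(read_buf, write_buf);
      }
      return 0;
    }
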
@@ -607,10 +644,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
set_mask_insert_pt(idx);
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx));
// store->setAlignment(16);
builder.Insert(store);
});
}
else {
if(auto *x = dynamic_cast<ir::downcast_inst*>(ins)){
vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)});
return;
}
tile *ti = tmap_[ins];
distributed_tile* result = (distributed_tile*)ti;
if(!ins->get_type()->is_tile_ty())
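
For context, this hunk does two small things: tile stores are now emitted one scalar store per index at the mask insertion point (the setAlignment(16) call is left disabled), and the new downcast_inst case reads element {0} of its operand so a one-element tile can be used as a scalar. A rough C++ model of that behaviour, where the mask array is an invented stand-in for the predicate installed by set_mask_insert_pt:

    #include <array>
    #include <cstdio>

    int main() {
      constexpr int n = 8;
      std::array<float, n> tile{1, 2, 3, 4, 5, 6, 7, 8};   // value tile
      std::array<float, n> memory{};                       // backing storage for the pointer tile
      std::array<float*, n> ptrs{};                        // pointer tile
      std::array<bool, n> mask{true, true, true, true, true, false, false, false};

      for (int i = 0; i < n; ++i)
        ptrs[i] = &memory[i];

      // One scalar store per tile index, skipped where the predicate
      // (for example a bounds check) is false.
      for (int i = 0; i < n; ++i)
        if (mask[i])
          *ptrs[i] = tile[i];

      // downcast: a tile with a single element is read back as a plain scalar.
      std::array<float, 1> one_element{42.0f};
      float scalar = one_element[0];

      std::printf("memory[4] = %g, memory[5] = %g, downcast scalar = %g\n",
                  memory[4], memory[5], scalar);
      return 0;
    }
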
@@ -727,31 +770,67 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ti->set_value(idx, in->get_value(idx));
});
}
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins) || (buffer_info_->is_double(ins)))
// trans
else if(dynamic_cast<ir::trans_inst*>(ins)) {
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
in->for_each([&](indices_t idx){
indices_t out_idx = idx;
std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
ti->set_value(out_idx, in->get_value(idx));
});
}
else if(buffer_info_->is_shared(ins))
return;
// matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
// dot
else if(auto dot = dynamic_cast<ir::dot_inst*>(ins)) {
ir::value *A = ins->get_operand(0);
ir::value *B = ins->get_operand(1);
ir::value *C = ins->get_operand(2);
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
bool AT = dot->is_a_trans();
bool BT = dot->is_b_trans();
distributed_tile *TC = (distributed_tile*)tmap_.at(C);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {idx[1], builder.getInt32(K)};
if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1)
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value();
for(unsigned K = 0; K < NK; ++K){
indices_t a_idx = {idx[0], builder.getInt32(K)};
indices_t b_idx = {builder.getInt32(K), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
}
else
{
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
distributed_tile *TB = (distributed_tile*)tmap_.at(B);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
indices_t a_idx = {idx[0], builder.getInt32(0)};
indices_t b_idx = {builder.getInt32(0), idx[1]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
std::swap(b_idx[0], b_idx[1]);
Value *a = TA->get_value(a_idx);
Value *b = TB->get_value(b_idx);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
result->set_value(idx, res);
});
}
}
// element-wise
else {
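
For context, the dot lowering above accumulates res = fmuladd(a, b, res) over the K axis, handles transposed operands simply by swapping an operand's index pair before the load, and switches to a register-only path (distributed tiles instead of shared tiles) when the reduction dimension is 1; the new trans_inst case similarly just rotates each index vector by one position. A scalar C++ model of the index logic for the general case, with tiny made-up shapes and data:

    #include <cmath>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Scalar model of the emitted loop nest: C[i][j] = fmuladd(A[a_idx], B[b_idx], C[i][j])
    // over the K axis, where a transposed operand just swaps its index pair,
    // like the std::swap(a_idx[0], a_idx[1]) calls in the selection pass.
    int main() {
      constexpr int M = 2, N = 2, K = 3;
      const bool AT = false, BT = false;        // mirrors dot->is_a_trans() / is_b_trans()

      // A is M x K (or K x M when AT), B is K x N (or N x K when BT); row-major storage.
      std::vector<float> A = {1, 2, 3, 4, 5, 6};
      std::vector<float> B = {1, 0, 0, 1, 1, 1};
      std::vector<float> C(M * N, 0.0f);

      auto at = [](const std::vector<float> &T, std::pair<int,int> idx, int cols) {
        return T[idx.first * cols + idx.second];
      };

      for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
          float res = C[i * N + j];
          for (int k = 0; k < K; ++k) {
            std::pair<int,int> a_idx{i, k};     // a_idx = {idx[0], K}
            std::pair<int,int> b_idx{k, j};     // b_idx = {K, idx[1]}
            if (AT) std::swap(a_idx.first, a_idx.second);
            if (BT) std::swap(b_idx.first, b_idx.second);
            float a = at(A, a_idx, AT ? M : K); // leading dimension follows the layout
            float b = at(B, b_idx, BT ? K : N);
            res = std::fma(a, b, res);          // stands in for llvm.fmuladd
          }
          C[i * N + j] = res;
        }

      for (int i = 0; i < M; ++i)
        std::printf("C[%d] = { %g, %g }\n", i, C[i * N + 0], C[i * N + 1]);
      return 0;
    }
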
@@ -858,6 +937,7 @@ void selection::run(ir::module &src, Module &dst) {
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
}
sh_mem_ptr_ = sh_mem_ptr;
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);
@@ -890,7 +970,7 @@ void selection::run(ir::module &src, Module &dst) {
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::basic_block* inc_block = phi->get_incoming_block(n);
ir::value* inc_val = phi->get_incoming_value(n);
ir::value* terminator = inc_block->get_inst_list().back();
ir::instruction* terminator = inc_block->get_inst_list().back();
BasicBlock *llvm_inc_block = last_block.at(inc_block);
shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val);
bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator);
@@ -920,8 +1000,8 @@ void selection::run(ir::module &src, Module &dst) {
});
}
else {
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
PHINode *llvm_phi = (PHINode*)llvm_value(phi, dst_builder);
Value *llvm_inc_val = llvm_value(inc_val, dst_builder);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
}
}
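
For context, the final hunks wire phi nodes up after the fact: the PHINode is created when its block is lowered, and addIncoming is called later, once the incoming values exist, now going through llvm_value so values are materialized on demand. The same two-step pattern against the plain LLVM C++ API looks roughly like this (the toy counter loop and module name are made up):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/IR/Verifier.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace llvm;

    int main() {
      LLVMContext ctx;
      Module mod("phi_example", ctx);
      IRBuilder<> b(ctx);

      // Toy loop: i starts at 0 and is incremented until it reaches 10.
      FunctionType *fty = FunctionType::get(b.getInt32Ty(), false);
      Function *fn = Function::Create(fty, Function::ExternalLinkage, "counter", &mod);
      BasicBlock *entry   = BasicBlock::Create(ctx, "entry", fn);
      BasicBlock *loop    = BasicBlock::Create(ctx, "loop", fn);
      BasicBlock *exit_bb = BasicBlock::Create(ctx, "exit", fn);

      b.SetInsertPoint(entry);
      b.CreateBr(loop);

      b.SetInsertPoint(loop);
      PHINode *i = b.CreatePHI(b.getInt32Ty(), 2, "i");
      Value *next = b.CreateAdd(i, b.getInt32(1), "next");
      Value *cond = b.CreateICmpSLT(next, b.getInt32(10), "cond");
      b.CreateCondBr(cond, loop, exit_bb);

      // Incoming values are wired after the blocks exist, one per predecessor,
      // mirroring how selection::run patches lowered ir::phi_node instructions.
      i->addIncoming(b.getInt32(0), entry);
      i->addIncoming(next, loop);

      b.SetInsertPoint(exit_bb);
      b.CreateRet(next);

      verifyFunction(*fn, &errs());
      mod.print(outs(), nullptr);
      return 0;
    }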