[GENERAL] Various improvements:

* Sparse einsum in triton.ops.einsum
* Hacky support for fixed-tile-size atomic-add
* Various bugfixes in parser
This commit is contained in:
Philippe Tillet
2020-10-25 11:55:58 -07:00
parent 444907589d
commit 049ab989b5
16 changed files with 574 additions and 331 deletions

View File

@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->get_order(0);
auto order = layouts_->get(ptr)->get_order();
size_t ld;
for(size_t i = 0; i < order.size(); i++){
ld = order[i];
if(ld < x->get_type()->get_tile_rank())
break;
}
//size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = alignment_->get(ptr, ld);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
}
// Lowers an atomic-add IR instruction to LLVM IR.
// Tile-typed instructions take the per-element masked path; everything else
// takes the `else` path further down.
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
if(add->get_type()->is_tile_ty()){
// Operands, per the local names used here: 0 = pointer, 1 = value, 2 = mask.
ir::value* ptr = add->get_operand(0);
ir::value* val = add->get_operand(1);
ir::value* msk = add->get_operand(2);
// Look up the tiles previously recorded in tmap_ for each operand.
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
// presumably for_each invokes the lambda once per index of ptr's tile -- TODO confirm
for_each(ptr, [&](indices_t idx){
// Per-index pointer, value and mask.
Value *rmw_ptr = ptrs->get_value(idx);
Value *rmw_val = vals->get_value(idx);
Value *rmw_msk = msks->get_value(idx);
// NOTE(review): current_bb is not read anywhere in the visible code.
BasicBlock *current_bb = builder_->GetInsertBlock();
Function *parent = builder_->GetInsertBlock()->getParent();
// Two new blocks in the enclosing function: one performing the atomic,
// one where control rejoins whether or not the mask was set.
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
// Only enter mask_then when this index's mask value is true.
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
builder_->SetInsertPoint(mask_then_bb);
// FAdd read-modify-write, Monotonic ordering, System sync scope.
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
AtomicOrdering::Monotonic,
SyncScope::System);
builder_->CreateBr(mask_done_bb);
// Leave the builder in mask_done so the next index (and whatever follows
// this instruction) is emitted after the rejoin point.
builder_->SetInsertPoint(mask_done_bb);
});
}
else{
// Non-tile path: the operand is looked up in vmap_ rather than tmap_.
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = vmap_.at(add->get_operand(0));
// NOTE(review): the `@@` line below is a diff hunk header, not code; the
// statements between the line above and the lines below are not shown here.
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
// Branch to tid_0_done_bb and continue emitting there.
builder_->CreateBr(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
// presumably emits a target-specific memory fence -- verify against tgt_'s implementation
tgt_->add_memfence(module, *builder_);
}
}
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
// Emits code for one IR basic block: points the builder at the block stored
// in vmap_ for `block`, visits every instruction in the block's instruction
// list, then records the builder's final insert block back into vmap_.
void generator::visit_basic_block(ir::basic_block * block) {
BasicBlock *parent = (BasicBlock*)vmap_[block];
builder_->SetInsertPoint(parent);
// NOTE(review): the two consecutive `for` headers below look like the old
// (brace-less) and new (braced) versions of the same line shown together by
// the diff view; read as plain code they would nest. Confirm against the file.
for(ir::instruction *i: block->get_inst_list())
for(ir::instruction *i: block->get_inst_list()){
// std::cout << typeid(*i).name() << std::endl;
visit_value(i);
}
// Instruction visitors can move the insert point (the masked atomic-add path
// creates new blocks and leaves the builder in the last one), so store the
// block the builder actually ended in rather than `parent`.
vmap_[block] = builder_->GetInsertBlock();
}