[GENERAL] Various improvements:
* Sparse einsum in triton.ops.einsum
* Hacky support for fixed-tile-size atomic-add
* Various bugfixes in parser
This commit is contained in:
@@ -364,7 +364,14 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->get_order(0);
|
||||
auto order = layouts_->get(ptr)->get_order();
|
||||
size_t ld;
|
||||
for(size_t i = 0; i < order.size(); i++){
|
||||
ld = order[i];
|
||||
if(ld < x->get_type()->get_tile_rank())
|
||||
break;
|
||||
}
|
||||
//size_t ld = layouts_->get(ptr)->get_order(0);
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||
@@ -652,6 +659,31 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
}
|
||||
|
||||
// Lowers an atomic-add instruction.
// - If the instruction's type is a tile, one masked atomic read-modify-write is
//   emitted per tile index (first branch below).
// - Otherwise the scalar branch (`else`) is taken.
// NOTE(review): only part of the scalar branch is visible in this excerpt; the
// hunk header in the middle of it elides several lines, so the comments there
// describe only the statements that are shown.
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
// Tile path: the instruction produces a tile, so handle it element by element.
if(add->get_type()->is_tile_ty()){
|
||||
// Operand 0 is used below as the address of the atomic operation.
ir::value* ptr = add->get_operand(0);
|
||||
// Operand 1 is used below as the value to add.
ir::value* val = add->get_operand(1);
|
||||
// Operand 2 is used below as the branch condition guarding each atomic.
ir::value* msk = add->get_operand(2);
|
||||
// Fetch the tiles stored in tmap_ for the three operands.
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
||||
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
||||
// Run the lambda for the indices of `ptr` (presumably once per element of
// its tile -- confirm against for_each).
for_each(ptr, [&](indices_t idx){
|
||||
// Per-index address, addend and mask.
Value *rmw_ptr = ptrs->get_value(idx);
|
||||
Value *rmw_val = vals->get_value(idx);
|
||||
Value *rmw_msk = msks->get_value(idx);
|
||||
// NOTE(review): current_bb is never read in the visible code; candidate for removal.
BasicBlock *current_bb = builder_->GetInsertBlock();
|
||||
// The two new blocks are created in the function that owns the current insert block.
Function *parent = builder_->GetInsertBlock()->getParent();
|
||||
// mask_then: performs the atomic; mask_done: continuation after it.
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
|
||||
// Jump to mask_then when the mask value is true, otherwise skip straight to mask_done.
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_then_bb);
|
||||
// The result of the RMW is discarded.
// NOTE(review): the operation is hard-coded to FAdd (floating-point add), so
// this path looks float-only -- confirm integer tiles cannot reach it.
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
|
||||
AtomicOrdering::Monotonic,
|
||||
SyncScope::System);
|
||||
builder_->CreateBr(mask_done_bb);
|
||||
// Leave the builder in mask_done so the next index (and any later code) is emitted there.
builder_->SetInsertPoint(mask_done_bb);
|
||||
});
|
||||
}
|
||||
// Scalar path: the instruction's type is not a tile.
else{
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
// The module is passed to the target's memory-fence helper at the end of this branch.
Module *module = current->getModule();
|
||||
// Scalar operands are looked up in vmap_ rather than tmap_.
Value *rmw_ptr = vmap_.at(add->get_operand(0));
|
||||
@@ -670,6 +702,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
// Close the (elided) preceding block by branching to tid_0_done_bb, then continue emitting there.
builder_->CreateBr(tid_0_done_bb);
|
||||
builder_->SetInsertPoint(tid_0_done_bb);
|
||||
// Ask the target to emit a memory fence at the current insert point.
tgt_->add_memfence(module, *builder_);
|
||||
}
|
||||
}
|
||||
|
||||
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||
@@ -1362,8 +1395,10 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
||||
/**
 * Emits code for every instruction of an IR basic block.
 *
 * The builder is positioned on the LLVM block currently mapped to `block` in
 * vmap_, each instruction of the block is lowered once, in list order, through
 * visit_value, and the mapping is then refreshed with the block the builder
 * ended up in.
 *
 * @param block IR basic block whose instructions are to be lowered.
 */
void generator::visit_basic_block(ir::basic_block * block) {
  // Resume emission in the LLVM block associated with this IR block.
  BasicBlock *parent = (BasicBlock*)vmap_[block];
  builder_->SetInsertPoint(parent);
  // Exactly one pass over the instruction list: a second, nested loop header
  // over the same list would lower every instruction once per instruction.
  for(ir::instruction *i: block->get_inst_list())
    visit_value(i);
  // Lowering an instruction may move the builder into a freshly created block
  // (the tile atomic-add lowering creates "mask_then"/"mask_done" blocks), so
  // record where emission actually ended for this IR block.
  vmap_[block] = builder_->GetInsertBlock();
}
|
||||
|
||||
|
Reference in New Issue
Block a user