[GENERAL] Various bugfixes
This commit is contained in:
committed by Philippe Tillet
parent 50587bbf4b
commit 8f8d36c7a4
@@ -362,6 +362,30 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
}
|
||||
|
||||
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
if(!x->get_type()->is_tile_ty()){
|
||||
Value *ptr = vmap_.at(x->get_pointer_operand());
|
||||
Value *mask = vmap_.at(x->get_mask_operand());
|
||||
BasicBlock *current_bb = builder_->GetInsertBlock();
|
||||
Function *parent = builder_->GetInsertBlock()->getParent();
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
|
||||
builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_then_bb);
|
||||
Value *result_then = builder_->CreateLoad(ptr);
|
||||
builder_->CreateBr(mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_done_bb);
|
||||
Value *result = nullptr;
|
||||
if(x->get_false_value_operand()){
|
||||
Value *result_false = vmap_.at(x->get_false_value_operand());
|
||||
result = builder_->CreatePHI(result_then->getType(), 2);
|
||||
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
|
||||
((PHINode*)result)->addIncoming(result_false, current_bb);
|
||||
}
|
||||
else
|
||||
result = result_then;
|
||||
vmap_[x] = result;
|
||||
return;
|
||||
}
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
auto order = layouts_->get(ptr)->get_order();
|
||||
@@ -677,6 +701,8 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
||||
}
|
||||
|
||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
|
||||
|
||||
if(add->get_type()->is_tile_ty()){
|
||||
ir::value* ptr = add->get_operand(0);
|
||||
ir::value* val = add->get_operand(1);
|
||||
@@ -684,21 +710,36 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
||||
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
||||
|
||||
for_each(ptr, [&](indices_t idx){
|
||||
Value *rmw_ptr = ptrs->get_value(idx);
|
||||
Value *rmw_val = vals->get_value(idx);
|
||||
Value *rmw_msk = msks->get_value(idx);
|
||||
BasicBlock *current_bb = builder_->GetInsertBlock();
|
||||
Function *parent = builder_->GetInsertBlock()->getParent();
|
||||
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
|
||||
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
|
||||
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_then_bb);
|
||||
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
|
||||
AtomicOrdering::Unordered,
|
||||
SyncScope::System);
|
||||
builder_->CreateBr(mask_done_bb);
|
||||
builder_->SetInsertPoint(mask_done_bb);
|
||||
// num bytes
|
||||
Type* ty = rmw_val->getType();
|
||||
size_t nbits = ty->getScalarSizeInBits();
|
||||
// extract pointer offset
|
||||
std::string offset = "";
|
||||
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
|
||||
if(gep->getNumIndices() == 1)
|
||||
if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
|
||||
offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
|
||||
rmw_ptr = gep->getPointerOperand();
|
||||
}
|
||||
rmw_ptr = builder_->CreateBitCast(rmw_ptr, ty->getPointerTo(1));
|
||||
// asm argument type
|
||||
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
|
||||
// asm function type
|
||||
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||
// asm string
|
||||
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||
std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;";
|
||||
std::string ty_id = nbits == 32 ? "f" : "h";
|
||||
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
|
||||
// create inline asm
|
||||
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||
// call asm
|
||||
builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val});
|
||||
});
|
||||
}
|
||||
else{
|
||||
@@ -803,6 +844,7 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
||||
indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i};
|
||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++)
|
||||
|
@@ -255,7 +255,6 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// std::cout << source << std::endl;
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
@@ -264,10 +263,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
|
||||
try{
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
|
||||
}catch(exception::cuda::base const &){
|
||||
#ifdef TRITON_LOG_PTX_ERROR
|
||||
std::cerr << "Compilation Failed! Log: " << std::endl;
|
||||
//#ifdef TRITON_LOG_PTX_ERROR
|
||||
std::cout << source << std::endl;
|
||||
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
||||
std::cerr << errbuf << std::endl;
|
||||
#endif
|
||||
//#endif
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
@@ -231,7 +231,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
||||
VisitExpr(condOp->exprFalse_);
|
||||
ir::value* false_val = ret_;
|
||||
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
|
||||
if(!false_val->get_type()->is_tile_ty())
|
||||
if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
|
||||
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
|
||||
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
||||
cond,
|
||||
|
@@ -238,8 +238,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
throw std::runtime_error("using too much shared memory");
|
||||
barriers.run(module);
|
||||
//ir::print(module, std::cout);
|
||||
isel.visit(module, *llvm);
|
||||
// ir::print(module, std::cout);
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
return res;
|
||||
}
|
||||
@@ -364,6 +364,7 @@ std::string function::preheader() {
|
||||
|
||||
DECLARATION(float, 64, 64);
|
||||
DECLARATION(half , 64, 64);
|
||||
DECLARATION(half , 128, 128);
|
||||
|
||||
extern int atomic_cas(int*, int, int);
|
||||
extern int atomic_xchg(int*, int);
|
||||
|
Reference in New Issue
Block a user