[CODEGEN] Bugfix for immediate offsets in inline PTX (#104)

This commit is contained in:
Philippe Tillet
2021-05-10 22:30:25 -04:00
committed by Philippe Tillet
parent 1e844ba78d
commit d10265f054
2 changed files with 20 additions and 13 deletions

View File

@@ -91,7 +91,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
liveness.run(ir); liveness.run(ir);
allocation.run(ir); allocation.run(ir);
barriers.run(ir); barriers.run(ir);
// ir::print(ir, std::cout); // ir::print(ir, std::cout);
isel.visit(ir, *llvm); isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm)); mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str()); ker = driver::kernel::create(&*mod, name.c_str());

View File

@@ -550,11 +550,15 @@ void generator::visit_load_inst(ir::load_inst* x){
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
// input ptr info // input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr); GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
Value *in_base = in_gep->getPointerOperand(); size_t in_off;
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin()); if(in_gep){
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
in_base = cst ? in_base : in_gep; in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
ptr = cst ? in_gep->getPointerOperand() : in_gep;
}
else{
in_off = 0;
}
Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
size_t nbits = dtsize*8; size_t nbits = dtsize*8;
@@ -634,7 +638,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// finally call inline ASM // finally call inline ASM
// --- // ---
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
std::vector<Value*> args = {pred, in_base}; std::vector<Value*> args = {pred, ptr};
for(Value *v: others) for(Value *v: others)
args.push_back(v); args.push_back(v);
Value *_ret = call(_asm, args); Value *_ret = call(_asm, args);
@@ -1713,11 +1717,14 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i]; auto idx = idxs_[arg][i];
// input ptr info // input ptr info
Value *ptr = vals_[arg][idx];
size_t in_off = 0;
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]); GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand(); if(in_gep){
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin()); ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
in_base = cst ? in_base : in_gep; ptr= cst ? in_gep->getPointerOperand() : in_gep;
}
// output ptr info // output ptr info
Value* out_base = shared[i].first; Value* out_base = shared[i].first;
int out_off = shared[i].second*dtsize; int out_off = shared[i].second*dtsize;
@@ -1729,9 +1736,9 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
// is_zero_false_value = cst->isZeroValue(); // is_zero_false_value = cst->isZeroValue();
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0)); Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0));
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;"; std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;";
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType(), builder_->getInt32Ty()}, false); FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false);
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true); InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true);
call(iasm, {out_base, in_base, src_size}); call(iasm, {out_base, ptr, src_size});
} }
std::string asm_str = "cp.async.commit_group;"; std::string asm_str = "cp.async.commit_group;";