From d10265f0549479f595003d5c39a68da2534326eb Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 10 May 2021 22:30:25 -0400 Subject: [PATCH] [CODEGEN] Bugfix for immediate offsets in inline PTX (#104) --- lib/codegen/pass.cc | 2 +- lib/codegen/selection/generator.cc | 31 ++++++++++++++++++------------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 0af65542f..77d97d941 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -91,7 +91,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, liveness.run(ir); allocation.run(ir); barriers.run(ir); - // ir::print(ir, std::cout); +// ir::print(ir, std::cout); isel.visit(ir, *llvm); mod = driver::module::create(dev, std::move(llvm)); ker = driver::kernel::create(&*mod, name.c_str()); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index abd7fe9fd..e7c242f23 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -550,11 +550,15 @@ void generator::visit_load_inst(ir::load_inst* x){ size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; // input ptr info GetElementPtrInst *in_gep = dyn_cast(ptr); - Value *in_base = in_gep->getPointerOperand(); - ConstantInt* cst = dyn_cast(in_gep->idx_begin()); - size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; - in_base = cst ? in_base : in_gep; - + size_t in_off; + if(in_gep){ + ConstantInt* cst = dyn_cast(in_gep->idx_begin()); + in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; + ptr = cst ? in_gep->getPointerOperand() : in_gep; + } + else{ + in_off = 0; + } Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; size_t nbits = dtsize*8; @@ -634,7 +638,7 @@ void generator::visit_load_inst(ir::load_inst* x){ // finally call inline ASM // --- InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); - std::vector args = {pred, in_base}; + std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); Value *_ret = call(_asm, args); @@ -1713,11 +1717,14 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ auto idx = idxs_[arg][i]; // input ptr info + Value *ptr = vals_[arg][idx]; + size_t in_off = 0; GetElementPtrInst *in_gep = dyn_cast(vals_[arg][idx]); - Value *in_base = in_gep->getPointerOperand(); - ConstantInt* cst = dyn_cast(in_gep->idx_begin()); - size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; - in_base = cst ? in_base : in_gep; + if(in_gep){ + ConstantInt* cst = dyn_cast(in_gep->idx_begin()); + in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; + ptr= cst ? in_gep->getPointerOperand() : in_gep; + } // output ptr info Value* out_base = shared[i].first; int out_off = shared[i].second*dtsize; @@ -1729,9 +1736,9 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ // is_zero_false_value = cst->isZeroValue(); Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0)); std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;"; - FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType(), builder_->getInt32Ty()}, false); + FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false); InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true); - call(iasm, {out_base, in_base, src_size}); + call(iasm, {out_base, ptr, src_size}); } std::string asm_str = "cp.async.commit_group;";