[CODEGEN] Bugfixes with FP32 async copy

This commit is contained in:
Philippe Tillet
2021-02-24 13:36:26 -05:00
parent 11215f0f03
commit 567a1a3d17
3 changed files with 12 additions and 11 deletions

View File

@@ -1467,26 +1467,26 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
}
shared.push_back({tmp[key], off});
}
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i];
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
int out_off = shared[i].second*2;
int out_off = shared[i].second*dtsize;
// asm
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
std::string mod = (in_vec*dtsize == 16) ? ".cg" : ".ca";
// Value* false_value = vals_[x->get_false_value_operand()][idx];
// bool is_zero_false_value = false;
// if(Constant* cst = dyn_cast<Constant>(false_value))
// is_zero_false_value = cst->isZeroValue();
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*2), i32(0));
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ", $2;";
Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0));
std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;";
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType(), builder_->getInt32Ty()}, false);
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true);
call(iasm, {out_base, in_base, src_size});