[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)

This commit is contained in:
Philippe Tillet
2021-08-14 21:07:01 -07:00
committed by GitHub
parent 6e7593b446
commit bb1eebb4b4
3 changed files with 45 additions and 4 deletions

View File

@@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
return result;
}
inline Value* generator::shfl_sync(Value* acc, int32_t i){
Type* ty = acc->getType();
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
return builder_->CreateBitCast(ret, ty);
}
/**
* \brief Code Generation for `reduce` (1D case)
*/
@@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
acc = !acc ? val : do_acc(acc, val);
}
// reduce within wrap
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
"shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
for(int i = 16; i > 0; i >>= 1)
acc = do_acc(acc, call(shfl, {acc, i32(i)}));
acc = do_acc(acc, shfl_sync(acc, i));
// pointers
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
@@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
builder_->SetInsertPoint(term);
Value* ret = load(gep(base, thread));
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
Value *current = call(shfl, {ret, i32(i)});
Value *current = shfl_sync(ret, i);
ret = do_acc(ret, current);
}
store(ret, gep(base, thread));