[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)
This commit is contained in:
@@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * \brief Emits a butterfly shuffle (`shfl.sync.bfly.b32`) of `acc`
 *
 * \param acc value to exchange; its LLVM type selects the code path below
 * \param i   constant forwarded as the `$2` operand of the shuffle
 * \return    a value of the same LLVM type as `acc`
 *
 * A value whose primitive size is at most 32 bits is handed to a single
 * inline-asm shuffle. A wider value is reinterpreted as <2 x f32>, each
 * element is shuffled through a recursive call, and the resulting pair
 * is bit-cast back to the original type.
 */
inline Value* generator::shfl_sync(Value* acc, int32_t i){
  Type* ty = acc->getType();
  if(ty->getPrimitiveSizeInBits() <= 32){
    // The asm callee is typed on `ty`, so it is only built on the path that
    // actually calls it; the wide path never needs a `b32` callee of its own.
    std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
    InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
    return call(shfl, {acc, i32(i)});
  }
  // NOTE(review): the cast below presumes `acc` is exactly 64 bits wide -- confirm
  // that no other width can reach this point.
  Type* pair_ty = vec_ty(f32_ty, 2);
  Value* halves = builder_->CreateBitCast(acc, pair_ty);
  Value* half0 = builder_->CreateExtractElement(halves, i32(0));
  Value* half1 = builder_->CreateExtractElement(halves, i32(1));
  // Shuffle each 32-bit half independently and rebuild the pair in order.
  Value* ret = UndefValue::get(pair_ty);
  ret = insert_elt(ret, shfl_sync(half0, i), i32(0));
  ret = insert_elt(ret, shfl_sync(half1, i), i32(1));
  return builder_->CreateBitCast(ret, ty);
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (1D case)
|
||||
*/
|
||||
@@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
acc = !acc ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within warp
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
|
||||
"shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
|
||||
for(int i = 16; i > 0; i >>= 1)
|
||||
acc = do_acc(acc, call(shfl, {acc, i32(i)}));
|
||||
acc = do_acc(acc, shfl_sync(acc, i));
|
||||
// pointers
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
|
||||
@@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
builder_->SetInsertPoint(term);
|
||||
Value* ret = load(gep(base, thread));
|
||||
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
|
||||
Value *current = call(shfl, {ret, i32(i)});
|
||||
Value *current = shfl_sync(ret, i);
|
||||
ret = do_acc(ret, current);
|
||||
}
|
||||
store(ret, gep(base, thread));
|
||||
|
Reference in New Issue
Block a user