[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)

This commit is contained in:
Philippe Tillet
2021-08-14 21:07:01 -07:00
committed by GitHub
parent 6e7593b446
commit bb1eebb4b4
3 changed files with 45 additions and 4 deletions

View File

@@ -167,6 +167,7 @@ public:
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);

View File

@@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
return result;
}
inline Value* generator::shfl_sync(Value* acc, int32_t i){
Type* ty = acc->getType();
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
return builder_->CreateBitCast(ret, ty);
}
/**
* \brief Code Generation for `reduce` (1D case)
*/
@@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
acc = !acc ? val : do_acc(acc, val);
}
// reduce within wrap
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
"shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
for(int i = 16; i > 0; i >>= 1)
acc = do_acc(acc, call(shfl, {acc, i32(i)}));
acc = do_acc(acc, shfl_sync(acc, i));
// pointers
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
@@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
builder_->SetInsertPoint(term);
Value* ret = load(gep(base, thread));
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
Value *current = call(shfl, {ret, i32(i)});
Value *current = shfl_sync(ret, i);
ret = do_acc(ret, current);
}
store(ret, gep(base, thread));

View File

@@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype, shape",
[(dtype, shape) \
for dtype in dtypes\
for shape in [128, 512]])
def test_reduce1d(dtype, shape, device='cuda'):
dtype = cvt[dtype]
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X + tl.arange(0, meta['BLOCK']))
tl.store(Z, tl.sum(x, axis=0))
x = triton.testing.random((shape,), dtype=dtype, device=device)
# triton result
z_tri = triton.testing.random((1,), dtype=dtype, device=device)
kernel[(1,)](x, z_tri, BLOCK=shape)
# torch result
z_ref = torch.sum(x).to(dtype)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
# test load