[BACKEND][CODEGEN] Fix reduce uint (#547)
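Summary: this commit adds unsigned-integer reduction support. The IR gains UMAX/UMIN reduce ops, codegen lowers them with LLVM's unsigned comparison predicates and the correct neutral elements, the Python frontend routes unsigned dtypes to the new ops, and the reduction tests are generalized to cover min/max/sum. Short notes with illustrative sketches follow each hunk.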
@@ -913,7 +913,7 @@ public:
 class reduce_inst: public builtin_inst {
 public:
   enum op_t{
-    ADD, SUB, MAX, MIN,
+    ADD, SUB, MAX, MIN, UMAX, UMIN,
     FADD, FSUB, FMAX, FMIN,
     XOR
   };
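Note: MAX and MIN lower to signed comparisons, which mis-rank unsigned values whose high bit is set, so the IR grows dedicated UMAX/UMIN ops. A worked illustration with numpy (illustration only, not part of the diff):

    import numpy as np

    # 4_000_000_000 fits in uint32, but the same bit pattern
    # reinterpreted as int32 is negative.
    x = np.array([7, 4_000_000_000], dtype=np.uint32)
    signed_view = x.view(np.int32)    # same bits, signed interpretation
    print(x.max())                    # 4000000000 -- correct unsigned MAX
    print(x[signed_view.argmax()])    # 7 -- what a signed MAX would pick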
@@ -119,6 +119,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define icmp_eq(...)   builder_->CreateICmpEQ(__VA_ARGS__)
 #define icmp_sge(...)  builder_->CreateICmpSGE(__VA_ARGS__)
 #define icmp_sle(...)  builder_->CreateICmpSLE(__VA_ARGS__)
+#define icmp_uge(...)  builder_->CreateICmpUGE(__VA_ARGS__)
+#define icmp_ule(...)  builder_->CreateICmpULE(__VA_ARGS__)
 #define icmp_ult(...)  builder_->CreateICmpULT(__VA_ARGS__)
 #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
 #define intrinsic(...)  builder_->CreateIntrinsic(__VA_ARGS__)
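Note: the two new macros wrap LLVM's CreateICmpUGE/CreateICmpULE builders, the unsigned counterparts of the existing signed predicates. A plain-Python model of the difference on 32-bit values (a sketch, not the LLVM API):

    def to_signed32(v: int) -> int:
        v &= 0xFFFFFFFF
        return v - (1 << 32) if v & 0x80000000 else v

    def icmp_uge(a: int, b: int) -> bool:
        return (a & 0xFFFFFFFF) >= (b & 0xFFFFFFFF)   # unsigned compare

    def icmp_sge(a: int, b: int) -> bool:
        return to_signed32(a) >= to_signed32(b)       # signed compare

    assert icmp_uge(0xFFFFFFFF, 1)        # UINT32_MAX >= 1 when unsigned
    assert not icmp_sge(0xFFFFFFFF, 1)    # but the same bits are -1 signed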
@@ -2498,6 +2500,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
     case ir::reduce_inst::SUB: return sub(x, y);
     case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
     case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
+    case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
+    case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
     case ir::reduce_inst::FADD: return fadd(x, y);
     case ir::reduce_inst::FSUB: return fsub(x, y);
     case ir::reduce_inst::FMAX: return max_num(x, y);
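Note: each case returns the binary combine that the reduction applies pairwise across lanes; for UMAX/UMIN that combine is a select over the new unsigned predicates. A Python stand-in for the fold (illustration only, not the C++ codegen):

    from functools import reduce

    def umax(x: int, y: int) -> int:
        # models select(icmp_uge(x, y), x, y) on 32-bit bit patterns
        return x if (x & 0xFFFFFFFF) >= (y & 0xFFFFFFFF) else y

    lanes = [7, 0x80000000, 42]       # 0x80000000 is negative as int32
    assert reduce(umax, lanes) == 0x80000000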
@@ -2510,9 +2514,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
   Value *neutral;
   switch(op) {
     case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
     case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
     case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
     case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
+    case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
     case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
     case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
     case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
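Note: the neutral element seeds lanes that carry no data, so it must never affect the result: 0 is the smallest uint32 (neutral for UMAX) and UINT32_MAX the largest (neutral for UMIN). Reusing the signed neutrals would be wrong, since INT32_MIN's bit pattern is 0x80000000, a large unsigned value. A sketch:

    from functools import reduce

    # On non-negative Python ints, max/min model unsigned UMAX/UMIN.
    UMAX_NEUTRAL = 0                # smallest uint32
    UMIN_NEUTRAL = 2**32 - 1        # UINT32_MAX, largest uint32

    assert reduce(max, [7, 19], UMAX_NEUTRAL) == 19
    assert reduce(min, [7, 19], UMIN_NEUTRAL) == 7
    # INT32_MIN would be a wrong UMAX seed: its bit pattern equals 2**31,
    # which as an unsigned value exceeds both inputs above.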
@@ -571,6 +571,8 @@ void init_triton_ir(py::module &&m) {
     .value("FADD", ir::reduce_inst::FADD)
     .value("MIN", ir::reduce_inst::MIN)
     .value("MAX", ir::reduce_inst::MAX)
+    .value("UMIN", ir::reduce_inst::UMIN)
+    .value("UMAX", ir::reduce_inst::UMAX)
     .value("FMIN", ir::reduce_inst::FMIN)
     .value("FMAX", ir::reduce_inst::FMAX)
     .value("XOR", ir::reduce_inst::XOR);
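Note: these pybind11 bindings expose the new op_t members to the Python frontend as ir.REDUCE_OP values. As a rough stand-in (an assumption about behavior, not the actual binding code), they act like members of a Python Enum, compared by identity in reduce_impl below:

    from enum import Enum

    # member order mirrors the op_t declaration above
    class REDUCE_OP(Enum):
        ADD = 0; SUB = 1; MAX = 2; MIN = 3; UMAX = 4; UMIN = 5
        FADD = 6; FSUB = 7; FMAX = 8; FMIN = 9; XOR = 10

    assert REDUCE_OP.MAX is not REDUCE_OP.UMAX   # identity checks stay distinct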
@@ -688,60 +688,78 @@ def test_f16_to_f8_rounding():
 # ---------------
 
 
-@pytest.mark.parametrize("dtype_str, shape",
-                         [(dtype, shape)
+@pytest.mark.parametrize("op, dtype_str, shape",
+                         [(op, dtype, shape)
+                          for op in ['min', 'max', 'sum']
                           for dtype in dtypes
                           for shape in [32, 64, 128, 512]])
-def test_reduce1d(dtype_str, shape, device='cuda'):
+def test_reduce1d(op, dtype_str, shape, device='cuda'):
 
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK: tl.constexpr):
         x = tl.load(X + tl.arange(0, BLOCK))
-        tl.store(Z, tl.sum(x, axis=0))
+        tl.store(Z, GENERATE_TEST_HERE)
 
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
+    # input
     rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
-    x[:] = 1
-    # numpy result
-    z_ref = np.sum(x).astype(getattr(np, dtype_str))
-    # triton result
     x_tri = to_triton(x, device=device)
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
+    # numpy result
+    z_ref = numpy_op(x).astype(getattr(np, dtype_str))
+    # triton result
     z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
     # compare
-    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+    if op == 'sum':
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+    else:
+        np.testing.assert_equal(z_ref, to_numpy(z_tri))
 
 
 reduce_configs1 = [
-    (dtype, (1, 1024), axis) for dtype in ['float32', 'uint32']
+    (op, dtype, (1, 1024), axis) for dtype in dtypes
+    for op in ['min', 'max', 'sum']
     for axis in [1]
 ]
 reduce_configs2 = [
-    ('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
+    (op, 'float32', shape, 1)
+    for op in ['min', 'max', 'sum']
+    for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
 ]
 
 
-@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
-def test_reduce2d(dtype_str, shape, axis, device='cuda'):
+@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
+def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
         range_m = tl.arange(0, BLOCK_M)
         range_n = tl.arange(0, BLOCK_N)
         x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
-        z = tl.sum(x, axis=AXIS)
+        z = GENERATE_TEST_HERE
         tl.store(Z + range_m, z)
 
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
     # input
-    x = numpy_random(shape, dtype_str=dtype_str)
-    # triton result
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
     x_tri = to_triton(x)
-    z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device)
-    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
-    # numpy reference result
-    z_ref = np.sum(x, axis=axis).astype(x.dtype)
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
+    # numpy result
+    z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str))
+    # triton result
+    z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device)
+    binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
     # compare
-    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+    if op == 'sum':
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+    else:
+        np.testing.assert_equal(z_ref, to_numpy(z_tri))
 
 
 # ---------------
 # test permute
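Note: the tests now template the reduction op into the kernel source via patch_kernel (a helper defined in this test file), since the op must appear literally inside the @triton.jit body. min/max results are exact, hence assert_equal, while sum keeps a relative tolerance. A minimal stand-in for the substitution step:

    def patch_source(src: str, patches: dict) -> str:
        # textual substitution before JIT compilation
        for placeholder, replacement in patches.items():
            src = src.replace(placeholder, replacement)
        return src

    src = "tl.store(Z, GENERATE_TEST_HERE)"
    op = 'max'
    print(patch_source(src, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'}))
    # -> tl.store(Z, tl.max(x, axis=0))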
@@ -136,6 +136,9 @@ class dtype:
     def is_int_signed(self):
         return self.name in dtype.SINT_TYPES
 
+    def is_int_unsigned(self):
+        return self.name in dtype.UINT_TYPES
+
     def is_int(self):
         return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES
 
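Note: is_int_unsigned mirrors is_int_signed over dtype.UINT_TYPES. A self-contained stand-in showing the predicate's behavior (the type-name lists here are representative assumptions, not the class's actual constants):

    SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
    UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']

    class dtype:
        def __init__(self, name): self.name = name
        def is_int_signed(self): return self.name in SINT_TYPES
        def is_int_unsigned(self): return self.name in UINT_TYPES
        def is_int(self): return self.name in SINT_TYPES + UINT_TYPES

    assert dtype('uint32').is_int_unsigned()
    assert not dtype('int32').is_int_unsigned()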
@@ -959,6 +959,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
         input = cast(input, tl.int32, builder)
 
+    # choose the right unsigned operation
+    if scalar_ty.is_int_unsigned():
+        if INT_OP is ir.REDUCE_OP.MIN:
+            INT_OP = ir.REDUCE_OP.UMIN
+        elif INT_OP is ir.REDUCE_OP.MAX:
+            INT_OP = ir.REDUCE_OP.UMAX
+
     # get result type
     shape = input.type.shape
     ret_shape = []
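Note: reduce_impl casts any integer type of at most 32 bits to int32, so a uint32 with its high bit set reads as negative to signed MIN/MAX; that is the bug this commit fixes. The new branch swaps in the unsigned ops for unsigned inputs. A self-contained sketch of the dispatch (stand-in enum, not the real ir module):

    from enum import Enum, auto

    class REDUCE_OP(Enum):
        MIN = auto(); MAX = auto(); UMIN = auto(); UMAX = auto()

    def pick_int_op(int_op: REDUCE_OP, is_unsigned: bool) -> REDUCE_OP:
        # unsigned tensors get unsigned comparisons; others are untouched
        if is_unsigned:
            if int_op is REDUCE_OP.MIN:
                return REDUCE_OP.UMIN
            elif int_op is REDUCE_OP.MAX:
                return REDUCE_OP.UMAX
        return int_op

    assert pick_int_op(REDUCE_OP.MAX, is_unsigned=True) is REDUCE_OP.UMAX
    assert pick_int_op(REDUCE_OP.MAX, is_unsigned=False) is REDUCE_OP.MAX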