[Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826)
@@ -29,12 +29,20 @@ def TT_RedOpAttr : I32EnumAttr<
     /*case*/
     [
       I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
-      I32EnumAttrCase<"MAX", 2, "max">,
-      I32EnumAttrCase<"MIN", 3, "min">,
-      I32EnumAttrCase<"FADD", 4, "fadd">,
-      I32EnumAttrCase<"FMAX", 5, "fmax">,
-      I32EnumAttrCase<"FMIN", 6, "fmin">,
-      I32EnumAttrCase<"XOR", 7, "xor">
+      I32EnumAttrCase<"FADD", 2, "fadd">,
+      I32EnumAttrCase<"MIN", 3, "min">,
+      I32EnumAttrCase<"MAX", 4, "max">,
+      I32EnumAttrCase<"UMIN", 5, "umin">,
+      I32EnumAttrCase<"UMAX", 6, "umax">,
+      I32EnumAttrCase<"ARGMIN", 7, "argmin">,
+      I32EnumAttrCase<"ARGMAX", 8, "argmax">,
+      I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
+      I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
+      I32EnumAttrCase<"FMIN", 11, "fmin">,
+      I32EnumAttrCase<"FMAX", 12, "fmax">,
+      I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
+      I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
+      I32EnumAttrCase<"XOR", 15, "xor">
     ]> {
   let cppNamespace = "::mlir::triton";
 }

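Note: the reworked enum gives every reduction its own case — signed MIN/MAX, unsigned UMIN/UMAX, float FMIN/FMAX, the arg-variants, and XOR — so signedness is decided where the element type is known rather than inside the backend switch. Below is a hedged sketch of that kind of dispatch; `choose_red_op` and the string names are illustrative only, not the actual Triton frontend code.

```python
# Illustrative only: pick a reduction case once the element dtype is known,
# so the backend no longer has to inspect signedness itself.
def choose_red_op(op: str, dtype_str: str) -> str:
    is_float = dtype_str.startswith('float') or dtype_str == 'bfloat16'
    is_unsigned = dtype_str.startswith('uint')
    if op == 'min':
        return 'FMIN' if is_float else ('UMIN' if is_unsigned else 'MIN')
    if op == 'max':
        return 'FMAX' if is_float else ('UMAX' if is_unsigned else 'MAX')
    if op == 'sum':
        return 'FADD' if is_float else 'ADD'
    raise ValueError(op)

assert choose_red_op('min', 'uint8') == 'UMIN'
assert choose_red_op('max', 'int16') == 'MAX'
assert choose_red_op('sum', 'float32') == 'FADD'
```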
@@ -1301,27 +1301,27 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
   case RedOp::ADD:
     acc = add(acc, cur);
     break;
-  case RedOp::MAX:
-    if (type.isUnsignedInteger())
-      acc = umax(acc, cur);
-    else
-      acc = smax(acc, cur);
-    break;
-  case RedOp::MIN:
-    if (type.isUnsignedInteger())
-      acc = umin(acc, cur);
-    else
-      acc = smin(acc, cur);
-    break;
   case RedOp::FADD:
     acc = fadd(acc.getType(), acc, cur);
     break;
-  case RedOp::FMAX:
-    acc = fmax(acc, cur);
+  case RedOp::MIN:
+    acc = smin(acc, cur);
+    break;
+  case RedOp::MAX:
+    acc = smax(acc, cur);
+    break;
+  case RedOp::UMIN:
+    acc = umin(acc, cur);
+    break;
+  case RedOp::UMAX:
+    acc = umax(acc, cur);
     break;
   case RedOp::FMIN:
     acc = fmin(acc, cur);
     break;
+  case RedOp::FMAX:
+    acc = fmax(acc, cur);
+    break;
   case RedOp::XOR:
     acc = xor_(acc, cur);
     break;

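With the signed/unsigned split now encoded in RedOp itself, the conversion no longer needs the old `type.isUnsignedInteger()` branches: each case maps straight to one of smin/smax/umin/umax/fmin/fmax. The distinction matters because the same bits order differently under signed and unsigned comparison; a small NumPy illustration (not part of the commit):

```python
import numpy as np

# The same 8-bit patterns viewed as unsigned vs. signed integers.
a, b = np.uint8(200), np.uint8(100)        # bit patterns 0xC8 and 0x64
sa, sb = a.view(np.int8), b.view(np.int8)  # reinterpreted: -56 and 100

print(max(a, b))    # 200 -- what an unsigned max (umax) computes
print(max(sa, sb))  # 100 -- what a signed max (smax) computes on the same bits
```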
@@ -87,8 +87,16 @@ void init_triton_ir(py::module &&m) {
       .value("FADD", mlir::triton::RedOp::FADD)
       .value("MIN", mlir::triton::RedOp::MIN)
       .value("MAX", mlir::triton::RedOp::MAX)
+      .value("UMIN", mlir::triton::RedOp::UMIN)
+      .value("UMAX", mlir::triton::RedOp::UMAX)
+      .value("ARGMIN", mlir::triton::RedOp::ARGMIN)
+      .value("ARGMAX", mlir::triton::RedOp::ARGMAX)
+      .value("ARGUMIN", mlir::triton::RedOp::ARGUMIN)
+      .value("ARGUMAX", mlir::triton::RedOp::ARGUMAX)
       .value("FMIN", mlir::triton::RedOp::FMIN)
       .value("FMAX", mlir::triton::RedOp::FMAX)
+      .value("ARGFMIN", mlir::triton::RedOp::ARGFMIN)
+      .value("ARGFMAX", mlir::triton::RedOp::ARGFMAX)
       .value("XOR", mlir::triton::RedOp::XOR);

   py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP")

@@ -870,121 +870,142 @@ def test_store_bool():
 # # ---------------


-# @pytest.mark.parametrize("op, dtype_str, shape",
-#                          [(op, dtype, shape)
-#                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-#                           for dtype in dtypes_with_bfloat16
-#                           for shape in [32, 64, 128, 512]])
-# def test_reduce1d(op, dtype_str, shape, device='cuda'):
-#     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
-
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z, BLOCK: tl.constexpr):
-#         x = tl.load(X + tl.arange(0, BLOCK))
-#         tl.store(Z, GENERATE_TEST_HERE)
-
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
-#     # input
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
-#     x_tri = to_triton(x, device=device)
-#     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
-#                 'argmin': np.argmin, 'argmax': np.argmax}[op]
-#     # numpy result
-#     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
-#     z_tri_dtype_str = z_dtype_str
-#     if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
-#         z_dtype_str = 'float32'
-#         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
-#         # trunc mantissa for a fair comparison of accuracy
-#         z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-#         z_tri_dtype_str = 'bfloat16'
-#     else:
-#         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
-#     # triton result
-#     z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
-#                       device=device, dst_type=z_tri_dtype_str)
-#     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
-#     z_tri = to_numpy(z_tri)
-#     # compare
-#     if op == 'sum':
-#         np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
-#     else:
-#         if op == 'argmin' or op == 'argmax':
-#             # argmin and argmax can have multiple valid indices.
-#             # so instead we compare the values pointed by indices
-#             np.testing.assert_equal(x[z_ref], x[z_tri])
-#         else:
-#             np.testing.assert_equal(z_ref, z_tri)
-
-
-# reduce_configs1 = [
-#     (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
-#     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-#     for axis in [1]
-# ]
-# reduce_configs2 = [
-#     (op, 'float32', shape, axis)
-#     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-#     for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
-#     for axis in [0, 1]
-# ]
-
-
-# @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
-# def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
-#         range_m = tl.arange(0, BLOCK_M)
-#         range_n = tl.arange(0, BLOCK_N)
-#         x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
-#         z = GENERATE_TEST_HERE
-#         if AXIS == 1:
-#             tl.store(Z + range_m, z)
-#         else:
-#             tl.store(Z + range_n, z)
-
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
-#     # input
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
-#     x_tri = to_triton(x)
-#     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
-#                 'argmin': np.argmin, 'argmax': np.argmax}[op]
-#     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
-#     z_tri_dtype_str = z_dtype_str
-#     # numpy result
-#     if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
-#         z_dtype_str = 'float32'
-#         z_tri_dtype_str = 'bfloat16'
-#         z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
-#         # trunc mantissa for a fair comparison of accuracy
-#         z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-#     else:
-#         z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
-#     # triton result
-#     z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
-#                       device=device, dst_type=z_tri_dtype_str)
-#     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
-#     z_tri = to_numpy(z_tri)
-#     # compare
-#     if op == 'sum':
-#         np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
-#     else:
-#         if op == 'argmin' or op == 'argmax':
-#             # argmin and argmax can have multiple valid indices.
-#             # so instead we compare the values pointed by indices
-#             z_ref_index = np.expand_dims(z_ref, axis=axis)
-#             z_tri_index = np.expand_dims(z_tri, axis=axis)
-#             z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
-#             z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
-#             np.testing.assert_equal(z_ref_value, z_tri_value)
-#         else:
-#             np.testing.assert_equal(z_ref, z_tri)
+def get_reduced_dtype(dtype_str, op):
+    if op == 'argmin' or op == 'argmax':
+        return 'int32'
+    if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
+        return 'int32'
+    if dtype_str == 'bfloat16':
+        return 'float32'
+    return dtype_str
+
+
+# TODO: [Qingyi] Fix argmin / argmax
+@pytest.mark.parametrize("op, dtype_str, shape",
+                         [(op, dtype, shape)
+                          for op in ['min', 'max', 'sum']
+                          for dtype in dtypes_with_bfloat16
+                          for shape in [32, 64, 128, 512]])
+def test_reduce1d(op, dtype_str, shape, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        tl.store(Z, GENERATE_TEST_HERE)
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
+    # input
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
+    x_tri = to_triton(x, device=device)
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
+    # numpy result
+    z_dtype_str = get_reduced_dtype(dtype_str, op)
+    z_tri_dtype_str = z_dtype_str
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        z_tri_dtype_str = 'bfloat16'
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+    # triton result
+    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
+                      device=device, dst_type=z_tri_dtype_str)
+    kernel[(1,)](x_tri, z_tri, BLOCK=shape)
+    z_tri = to_numpy(z_tri)
+    # compare
+    if op == 'sum':
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
+    else:
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices.
+            # so instead we compare the values pointed by indices
+            np.testing.assert_equal(x[z_ref], x[z_tri])
+        else:
+            np.testing.assert_equal(z_ref, z_tri)
+
+
+# TODO: [Qingyi] Fix argmin / argmax
+reduce_configs1 = [
+    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
+    for op in ['min', 'max', 'sum']
+    for axis in [1]
+]
+
+
+# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
+# exceeds the limit of 99KB
+reduce2d_shapes = [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128)]
+if 'V100' in torch.cuda.get_device_name(0):
+    reduce2d_shapes += [(128, 256) and (32, 1024)]
+
+
+reduce_configs2 = [
+    (op, 'float32', shape, axis)
+    for op in ['min', 'max', 'sum']
+    for shape in reduce2d_shapes
+    for axis in [0, 1]
+]
+
+
+@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
+def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+        range_m = tl.arange(0, BLOCK_M)
+        range_n = tl.arange(0, BLOCK_N)
+        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
+        z = GENERATE_TEST_HERE
+        if AXIS == 1:
+            tl.store(Z + range_m, z)
+        else:
+            tl.store(Z + range_n, z)
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
+    # input
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    x_tri = to_triton(x)
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
+    z_dtype_str = get_reduced_dtype(dtype_str, op)
+    z_tri_dtype_str = z_dtype_str
+    # numpy result
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_tri_dtype_str = 'bfloat16'
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+    # triton result
+    z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
+                      device=device, dst_type=z_tri_dtype_str)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    z_tri = to_numpy(z_tri)
+    # compare
+    if op == 'sum':
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
+    else:
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices.
+            # so instead we compare the values pointed by indices
+            z_ref_index = np.expand_dims(z_ref, axis=axis)
+            z_tri_index = np.expand_dims(z_tri, axis=axis)
+            z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
+            z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
+            np.testing.assert_equal(z_ref_value, z_tri_value)
+        else:
+            np.testing.assert_equal(z_ref, z_tri)


 # # ---------------
 # # test permute

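For reference, the dtype-widening rule that the new `get_reduced_dtype` helper centralizes plays out like this; the asserts follow directly from the definition in the hunk above:

```python
# Expected mappings, given the definition of get_reduced_dtype above.
assert get_reduced_dtype('int8', 'sum') == 'int32'        # narrow ints widen to int32
assert get_reduced_dtype('uint16', 'max') == 'int32'
assert get_reduced_dtype('bfloat16', 'min') == 'float32'  # bf16 reference kept in fp32
assert get_reduced_dtype('float32', 'sum') == 'float32'   # wide types are unchanged
assert get_reduced_dtype('int64', 'argmax') == 'int32'    # arg-reductions return indices
```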
@@ -5,11 +5,20 @@ from torch.testing import assert_close
 import triton
 import triton.language as tl

-dtype_mapping = {
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'float64': torch.float64,
-}
+int_dtypes = ['int8', 'int16', 'int32', 'int64']
+uint_dtypes = ['uint8']  # PyTorch does not support uint16/uint32/uint64
+float_dtypes = ['float16', 'float32', 'float64']
+dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
+dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
+
+
+def get_reduced_dtype(dtype):
+    if dtype in [torch.int8, torch.int16, torch.uint8]:
+        return torch.int32
+    if dtype in [torch.bfloat16]:
+        return torch.float32
+    return dtype


 def patch_kernel(template, to_replace):

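The torch-side helper mirrors the same rule. Assuming the definitions from the hunk above are in scope, it behaves as follows:

```python
import torch

# Narrow integer inputs are accumulated/compared in int32; wider types keep
# their dtype (these asserts follow directly from get_reduced_dtype above).
assert get_reduced_dtype(torch.int8) is torch.int32
assert get_reduced_dtype(torch.uint8) is torch.int32
assert get_reduced_dtype(torch.bfloat16) is torch.float32
assert get_reduced_dtype(torch.float64) is torch.float64
```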
@@ -40,7 +49,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
 reduce1d_configs = [
     (op, dtype, shape)
     for op in ['sum', 'min', 'max']
-    for dtype in ['float16', 'float32', 'float64']
+    for dtype in dtypes
     for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
 ]

@@ -48,11 +57,18 @@ reduce1d_configs = [
 @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
 def test_reduce1d(op, dtype, shape):
     dtype = dtype_mapping[dtype]
-    x = torch.randn((shape,), device='cuda', dtype=dtype)
+    reduced_dtype = get_reduced_dtype(dtype)
+
+    if dtype.is_floating_point:
+        x = torch.randn((shape,), device='cuda', dtype=dtype)
+    elif dtype is torch.uint8:
+        x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
+    else:
+        x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
     z = torch.empty(
         tuple(),
         device=x.device,
-        dtype=dtype,
+        dtype=reduced_dtype,
     )

     kernel = patch_kernel(reduce1d_kernel, {'OP': op})

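The three-way branch when building inputs is needed because torch.randn only produces floating-point tensors and uint8 cannot represent negative values. A minimal standalone sketch of the same pattern (CPU tensors and an illustrative shape, not taken from the commit):

```python
import torch

# Floating dtypes use randn; unsigned ints need a non-negative range;
# signed ints get a small symmetric range so sums stay well within bounds.
x_f = torch.randn((8,), dtype=torch.float16)
x_u = torch.randint(0, 20, (8,), dtype=torch.uint8)
x_i = torch.randint(-20, 20, (8,), dtype=torch.int16)
```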
@@ -60,13 +76,13 @@ def test_reduce1d(op, dtype, shape):
     kernel[grid](x_ptr=x, z_ptr=z, block=shape)

     if op == 'sum':
-        golden_z = torch.sum(x, dtype=dtype)
+        golden_z = torch.sum(x, dtype=reduced_dtype)
     elif op == 'min':
-        golden_z = torch.min(x)
+        golden_z = torch.min(x).to(reduced_dtype)
     else:
-        golden_z = torch.max(x)
+        golden_z = torch.max(x).to(reduced_dtype)

-    if op == 'sum':
+    if dtype.is_floating_point and op == 'sum':
         if shape >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
         elif shape >= 32:

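Passing dtype=reduced_dtype to torch.sum matters because, to my understanding, PyTorch promotes integer inputs to int64 by default, which would not match the int32 buffer the kernel writes into; a quick check of that assumption:

```python
import torch

x = torch.arange(4, dtype=torch.int8)
print(torch.sum(x).dtype)                     # torch.int64 (default integer promotion)
print(torch.sum(x, dtype=torch.int32).dtype)  # torch.int32, matching the test's z buffer
```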
@@ -80,7 +96,7 @@ def test_reduce1d(op, dtype, shape):
 reduce2d_configs = [
     (op, dtype, shape, axis)
     for op in ['sum', 'min', 'max']
-    for dtype in ['float16', 'float32', 'float64']
+    for dtype in dtypes
     for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
     for axis in [0, 1]
 ]

@@ -89,22 +105,29 @@ reduce2d_configs = [
 @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
 def test_reduce2d(op, dtype, shape, axis):
     dtype = dtype_mapping[dtype]
-    x = torch.randn(shape, device='cuda', dtype=dtype)
+    reduced_dtype = get_reduced_dtype(dtype)
     reduced_shape = (shape[1 - axis],)
-    z = torch.empty(reduced_shape, device=x.device, dtype=dtype)
+
+    if dtype.is_floating_point:
+        x = torch.randn(shape, device='cuda', dtype=dtype)
+    elif dtype is torch.uint8:
+        x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
+    else:
+        x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
+    z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)

     kernel = patch_kernel(reduce2d_kernel, {'OP': op})
     grid = (1,)
     kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])

     if op == 'sum':
-        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype)
+        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
     elif op == 'min':
-        golden_z = torch.min(x, dim=axis, keepdim=False)[0]
+        golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
     else:
-        golden_z = torch.max(x, dim=axis, keepdim=False)[0]
+        golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)

-    if op == 'sum':
+    if dtype.is_floating_point and op == 'sum':
         if shape[axis] >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
         elif shape[axis] >= 32: