[Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826)
@@ -5,11 +5,20 @@ from torch.testing import assert_close
 import triton
 import triton.language as tl

-dtype_mapping = {
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'float64': torch.float64,
-}
+int_dtypes = ['int8', 'int16', 'int32', 'int64']
+uint_dtypes = ['uint8']  # PyTorch does not support uint16/uint32/uint64
+float_dtypes = ['float16', 'float32', 'float64']
+dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
+dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
+
+
+def get_reduced_dtype(dtype):
+    if dtype in [torch.int8, torch.int16, torch.uint8]:
+        return torch.int32
+    if dtype in [torch.bfloat16]:
+        return torch.float32
+    return dtype


 def patch_kernel(template, to_replace):
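Taken together, the new module-level helpers replace the hand-written float-only dtype table. A quick standalone check of what they produce (the asserts are illustrative, not part of the commit):

import torch

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8']  # PyTorch does not support uint16/uint32/uint64
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes

# torch.__dict__['int8'] is torch.int8, etc., so the comprehension rebuilds
# the old mapping and extends it to every listed dtype in one line.
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}

def get_reduced_dtype(dtype):
    # Narrow integer inputs accumulate into int32, bfloat16 into float32;
    # everything else reduces in its own dtype.
    if dtype in [torch.int8, torch.int16, torch.uint8]:
        return torch.int32
    if dtype in [torch.bfloat16]:
        return torch.float32
    return dtype

assert dtype_mapping['uint8'] is torch.uint8
assert get_reduced_dtype(torch.int8) is torch.int32
assert get_reduced_dtype(torch.float64) is torch.float64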
@@ -40,7 +49,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
 reduce1d_configs = [
     (op, dtype, shape)
     for op in ['sum', 'min', 'max']
-    for dtype in ['float16', 'float32', 'float64']
+    for dtype in dtypes
     for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
 ]

@@ -48,11 +57,18 @@ reduce1d_configs = [
 @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
 def test_reduce1d(op, dtype, shape):
     dtype = dtype_mapping[dtype]
-    x = torch.randn((shape,), device='cuda', dtype=dtype)
+    reduced_dtype = get_reduced_dtype(dtype)
+
+    if dtype.is_floating_point:
+        x = torch.randn((shape,), device='cuda', dtype=dtype)
+    elif dtype is torch.uint8:
+        x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
+    else:
+        x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
     z = torch.empty(
         tuple(),
         device=x.device,
-        dtype=dtype,
+        dtype=reduced_dtype,
     )

     kernel = patch_kernel(reduce1d_kernel, {'OP': op})
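The three-way branch is needed because torch.randn only supports floating-point dtypes, and torch.randint rejects a negative low bound for unsigned dtypes. A small demonstration of why uint8 gets its own case (behavior as of the PyTorch versions I have checked):

import torch

try:
    torch.randint(-20, 20, (4,), dtype=torch.uint8)
except RuntimeError as err:
    print(err)  # the negative low bound is out of range for uint8

x = torch.randint(0, 20, (4,), dtype=torch.uint8)  # non-negative bounds work
print(x)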
@@ -60,13 +76,13 @@ def test_reduce1d(op, dtype, shape):
     kernel[grid](x_ptr=x, z_ptr=z, block=shape)

     if op == 'sum':
-        golden_z = torch.sum(x, dtype=dtype)
+        golden_z = torch.sum(x, dtype=reduced_dtype)
     elif op == 'min':
-        golden_z = torch.min(x)
+        golden_z = torch.min(x).to(reduced_dtype)
     else:
-        golden_z = torch.max(x)
+        golden_z = torch.max(x).to(reduced_dtype)

-    if op == 'sum':
+    if dtype.is_floating_point and op == 'sum':
         if shape >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
         elif shape >= 32:
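Two details worth noting in the golden computation: torch.sum promotes integer inputs to int64 unless a dtype is given, so dtype=reduced_dtype keeps the reference in the same type the kernel writes; and the loose-tolerance branch is now gated on dtype.is_floating_point because integer reductions are exact while float sums accumulate rounding error. A quick illustration of both (my own sketch, not from the commit):

import torch

x = torch.randint(-20, 20, (1024,), dtype=torch.int8)
assert torch.sum(x).dtype is torch.int64                     # default integer promotion
assert torch.sum(x, dtype=torch.int32).dtype is torch.int32  # matches the kernel output

f = torch.full((1024,), 0.1, dtype=torch.float16)
print(torch.sum(f))  # drifts from the exact 102.4, hence rtol/atol for float sums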
@@ -80,7 +96,7 @@ def test_reduce1d(op, dtype, shape):
 reduce2d_configs = [
     (op, dtype, shape, axis)
     for op in ['sum', 'min', 'max']
-    for dtype in ['float16', 'float32', 'float64']
+    for dtype in dtypes
     for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
     for axis in [0, 1]
 ]
@@ -89,22 +105,29 @@ reduce2d_configs = [
 @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
 def test_reduce2d(op, dtype, shape, axis):
     dtype = dtype_mapping[dtype]
-    x = torch.randn(shape, device='cuda', dtype=dtype)
+    reduced_dtype = get_reduced_dtype(dtype)
     reduced_shape = (shape[1 - axis],)
-    z = torch.empty(reduced_shape, device=x.device, dtype=dtype)
+
+    if dtype.is_floating_point:
+        x = torch.randn(shape, device='cuda', dtype=dtype)
+    elif dtype is torch.uint8:
+        x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
+    else:
+        x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
+    z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)

     kernel = patch_kernel(reduce2d_kernel, {'OP': op})
     grid = (1,)
     kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])

     if op == 'sum':
-        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype)
+        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
     elif op == 'min':
-        golden_z = torch.min(x, dim=axis, keepdim=False)[0]
+        golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
     else:
-        golden_z = torch.max(x, dim=axis, keepdim=False)[0]
+        golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)

-    if op == 'sum':
+    if dtype.is_floating_point and op == 'sum':
         if shape[axis] >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
         elif shape[axis] >= 32:
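With the dtype axis widened from three float types to eight entries, the parametrized grids grow accordingly; a rough count (my arithmetic, not stated in the commit):

ops, dtypes_old, dtypes_new = 3, 3, 8        # 4 int + 1 uint + 3 float
shapes_1d, shapes_2d, axes = 8, 8, 2
print(ops * dtypes_old * shapes_1d)          # 72 reduce1d cases before
print(ops * dtypes_new * shapes_1d)          # 192 reduce1d cases after
print(ops * dtypes_new * shapes_2d * axes)   # 384 reduce2d cases after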