[Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826)

This commit is contained in:
Qingyi Liu
2022-11-01 17:42:59 +08:00
committed by GitHub
parent 031c2ae77b
commit cdc0ec5077
5 changed files with 208 additions and 148 deletions

View File

@@ -5,11 +5,20 @@ from torch.testing import assert_close
import triton
import triton.language as tl
dtype_mapping = {
'float16': torch.float16,
'float32': torch.float32,
'float64': torch.float64,
}
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8'] # PyTorch does not support uint16/uint32/uint64
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
def get_reduced_dtype(dtype):
if dtype in [torch.int8, torch.int16, torch.uint8]:
return torch.int32
if dtype in [torch.bfloat16]:
return torch.float32
return dtype
def patch_kernel(template, to_replace):
@@ -40,7 +49,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
reduce1d_configs = [
(op, dtype, shape)
for op in ['sum', 'min', 'max']
for dtype in ['float16', 'float32', 'float64']
for dtype in dtypes
for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
]
@@ -48,11 +57,18 @@ reduce1d_configs = [
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape):
dtype = dtype_mapping[dtype]
x = torch.randn((shape,), device='cuda', dtype=dtype)
reduced_dtype = get_reduced_dtype(dtype)
if dtype.is_floating_point:
x = torch.randn((shape,), device='cuda', dtype=dtype)
elif dtype is torch.uint8:
x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
else:
x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
z = torch.empty(
tuple(),
device=x.device,
dtype=dtype,
dtype=reduced_dtype,
)
kernel = patch_kernel(reduce1d_kernel, {'OP': op})
@@ -60,13 +76,13 @@ def test_reduce1d(op, dtype, shape):
kernel[grid](x_ptr=x, z_ptr=z, block=shape)
if op == 'sum':
golden_z = torch.sum(x, dtype=dtype)
golden_z = torch.sum(x, dtype=reduced_dtype)
elif op == 'min':
golden_z = torch.min(x)
golden_z = torch.min(x).to(reduced_dtype)
else:
golden_z = torch.max(x)
golden_z = torch.max(x).to(reduced_dtype)
if op == 'sum':
if dtype.is_floating_point and op == 'sum':
if shape >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1)
elif shape >= 32:
@@ -80,7 +96,7 @@ def test_reduce1d(op, dtype, shape):
reduce2d_configs = [
(op, dtype, shape, axis)
for op in ['sum', 'min', 'max']
for dtype in ['float16', 'float32', 'float64']
for dtype in dtypes
for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
for axis in [0, 1]
]
@@ -89,22 +105,29 @@ reduce2d_configs = [
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis):
dtype = dtype_mapping[dtype]
x = torch.randn(shape, device='cuda', dtype=dtype)
reduced_dtype = get_reduced_dtype(dtype)
reduced_shape = (shape[1 - axis],)
z = torch.empty(reduced_shape, device=x.device, dtype=dtype)
if dtype.is_floating_point:
x = torch.randn(shape, device='cuda', dtype=dtype)
elif dtype is torch.uint8:
x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
else:
x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)
kernel = patch_kernel(reduce2d_kernel, {'OP': op})
grid = (1,)
kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])
if op == 'sum':
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype)
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
elif op == 'min':
golden_z = torch.min(x, dim=axis, keepdim=False)[0]
golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
else:
golden_z = torch.max(x, dim=axis, keepdim=False)[0]
golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
if op == 'sum':
if dtype.is_floating_point and op == 'sum':
if shape[axis] >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1)
elif shape[axis] >= 32: