165 lines
5.4 KiB
Python
165 lines
5.4 KiB
Python
import pytest
|
|
import numpy as np
|
|
import torch
|
|
from torch.testing import assert_close
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
# Dtype names exercised by the reduction tests, grouped by kind.
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8']  # PyTorch does not support uint16/uint32/uint64
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
# ``dtypes`` plus bfloat16 (kept as a separate list so ``dtypes`` stays unchanged).
dtypes_with_bfloat16 = dtypes + ['bfloat16']
# Maps every dtype name above, bfloat16 included, to its ``torch.dtype`` object.
dtype_mapping = {dtype_str: getattr(torch, dtype_str) for dtype_str in dtypes_with_bfloat16}
|
|
|
|
|
|
def get_reduced_dtype(op, dtype):
    """Return the dtype the tests expect a reduction ``op`` over ``dtype`` to produce.

    - 'argmin' / 'argmax' always yield ``torch.int32`` (index results).
    - int8, int16 and uint8 inputs are widened to ``torch.int32``.
    - bfloat16 inputs are widened to ``torch.float32``.
    - Any other dtype is returned as-is.
    """
    index_op = op in ('argmin', 'argmax')
    narrow_int = dtype in (torch.int8, torch.int16, torch.uint8)
    if index_op or narrow_int:
        return torch.int32
    return torch.float32 if dtype in (torch.bfloat16,) else dtype
|
|
|
|
|
|
def patch_kernel(template, to_replace):
    """Build a fresh JIT kernel from ``template`` with its source text rewritten.

    A new ``triton.JITFunction`` is constructed from ``template.fn``. Then, for
    each ``(key, value)`` pair of ``to_replace`` in iteration order, every
    occurrence of ``key`` in the new kernel's ``src`` is replaced by ``value``.
    This function does not assign to ``template``.

    Returns:
        The newly created, patched ``triton.JITFunction``.
    """
    patched = triton.JITFunction(template.fn)
    for placeholder, substitute in to_replace.items():
        current_src = patched.src
        # Assign once per entry, mirroring a plain in-place replace loop.
        patched.src = current_src.replace(placeholder, substitute)
    return patched
|
|
|
|
|
|
# Template 1D reduction kernel. `OP` below is a textual placeholder: callers
# obtain a runnable kernel via patch_kernel(reduce1d_kernel, {'OP': <name>}),
# where <name> is a tl reduction such as sum, min, max, argmin, argmax, xor_sum.
# It loads `block` consecutive elements starting at x_ptr, reduces them along
# axis 0, and stores the single result at z_ptr.
@triton.jit
def reduce1d_kernel(x_ptr, z_ptr, block: tl.constexpr):
    # No tl.program_id and no mask: offsets 0..block-1 are read unconditionally,
    # so the input presumably holds at least `block` elements (TODO confirm callers).
    x = tl.load(x_ptr + tl.arange(0, block))
    # The reduction name on the next line is substituted by patch_kernel.
    tl.store(z_ptr, tl.OP(x, axis=0))
|
|
|
|
|
|
# Template 2D reduction kernel. `OP` is a textual placeholder replaced via
# patch_kernel(reduce2d_kernel, {'OP': <name>}) before the kernel is launched.
# It loads a block_m x block_n tile from x_ptr and reduces it along `axis`.
# The 1D result goes to z_ptr: block_n values when axis == 0, block_m otherwise.
@triton.jit
def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr):
    range_m = tl.arange(0, block_m)
    range_n = tl.arange(0, block_n)
    # Element (i, j) is read from offset i * block_n + j. This assumes x is a
    # contiguous row-major buffer with exactly block_n columns.
    x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :])
    # The reduction name on the next line is substituted by patch_kernel.
    z = tl.OP(x, axis=axis)
    # `axis` is a tl.constexpr, so this branch is presumably resolved at compile time.
    if axis == 0:
        # Reduced across rows: one value per column.
        tl.store(z_ptr + range_n, z)
    else:
        # Reduced across columns: one value per row.
        tl.store(z_ptr + range_m, z)
|
|
|
|
|
|
# (op, dtype, shape) cases for test_reduce1d: every reduction op, every dtype
# name in ``dtypes``, and a range of power-of-two input lengths.
# xor_sum is only tested on integer dtypes, so its float combinations are not
# generated. Generating them would produce cases that pass without testing anything.
reduce1d_configs = [
    (op, dtype, shape)
    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
    for dtype in dtypes
    for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
    if not (op == 'xor_sum' and dtype in float_dtypes)
]
|
|
|
|
|
|
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape):
    """Compare a 1D Triton reduction against a PyTorch / NumPy reference.

    Args:
        op: reduction name substituted into ``reduce1d_kernel``
            ('sum', 'min', 'max', 'argmin', 'argmax' or 'xor_sum').
        dtype: input dtype name; a key of ``dtype_mapping``.
        shape: number of input elements, also passed as the kernel's ``block``.
    """
    if op == 'xor_sum' and dtype in float_dtypes:
        # Report the combination as skipped instead of as a silent pass.
        pytest.skip('xor_sum is only tested on integer dtypes')

    torch_dtype = dtype_mapping[dtype]
    reduced_dtype = get_reduced_dtype(op, torch_dtype)

    # Small-magnitude inputs; the uint8 range stays non-negative.
    if torch_dtype.is_floating_point:
        x = torch.randn((shape,), device='cuda', dtype=torch_dtype)
    elif torch_dtype is torch.uint8:
        x = torch.randint(0, 20, (shape,), device='cuda', dtype=torch_dtype)
    else:
        x = torch.randint(-20, 20, (shape,), device='cuda', dtype=torch_dtype)
    # 0-d output tensor that receives the single reduced value.
    z = torch.empty(tuple(), device=x.device, dtype=reduced_dtype)

    kernel = patch_kernel(reduce1d_kernel, {'OP': op})
    kernel[(1,)](x_ptr=x, z_ptr=z, block=shape)

    if op == 'sum':
        golden_z = torch.sum(x, dtype=reduced_dtype)
    elif op == 'min':
        golden_z = torch.min(x).to(reduced_dtype)
    elif op == 'max':
        golden_z = torch.max(x).to(reduced_dtype)
    elif op == 'argmin':
        golden_z = torch.argmin(x).to(reduced_dtype)
    elif op == 'argmax':
        golden_z = torch.argmax(x).to(reduced_dtype)
    elif op == 'xor_sum':
        # PyTorch has no xor reduction, so compute the reference with NumPy.
        xor_npy = np.bitwise_xor.reduce(x.cpu().numpy())
        golden_z = torch.tensor(xor_npy, dtype=reduced_dtype, device=x.device)
    else:
        raise RuntimeError(f'Unknown reduce op {op}')

    if torch_dtype.is_floating_point and op == 'sum':
        # Float sums are order-sensitive; allow more error for longer inputs.
        if shape >= 256:
            rtol, atol = 0.05, 0.1
        elif shape >= 32:
            rtol, atol = 0.05, 0.02
        else:
            rtol, atol = 0.01, 0.01
    else:
        rtol, atol = 0.001, 0.001
    assert_close(z, golden_z, rtol=rtol, atol=atol)
|
|
|
|
|
|
# (op, dtype, shape, axis) cases for test_reduce2d: every reduction op, every
# dtype name in ``dtypes``, a set of (rows, cols) tile shapes, and both axes.
# xor_sum is only tested on integer dtypes, so its float combinations are not
# generated. Generating them would produce cases that pass without testing anything.
reduce2d_configs = [
    (op, dtype, shape, axis)
    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
    for dtype in dtypes
    for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
    for axis in [0, 1]
    if not (op == 'xor_sum' and dtype in float_dtypes)
]
|
|
|
|
|
|
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis):
    """Compare a 2D Triton reduction along ``axis`` against a PyTorch / NumPy reference.

    Args:
        op: reduction name substituted into ``reduce2d_kernel``
            ('sum', 'min', 'max', 'argmin', 'argmax' or 'xor_sum').
        dtype: input dtype name; a key of ``dtype_mapping``.
        shape: (rows, cols) of the input, passed as ``block_m`` / ``block_n``.
        axis: dimension to reduce, 0 or 1.
    """
    if op == 'xor_sum' and dtype in float_dtypes:
        # Report the combination as skipped instead of as a silent pass.
        pytest.skip('xor_sum is only tested on integer dtypes')

    torch_dtype = dtype_mapping[dtype]
    reduced_dtype = get_reduced_dtype(op, torch_dtype)
    # Reducing a 2D input along `axis` leaves only the other dimension.
    reduced_shape = (shape[1 - axis],)

    # Small-magnitude inputs; the uint8 range stays non-negative.
    if torch_dtype.is_floating_point:
        x = torch.randn(shape, device='cuda', dtype=torch_dtype)
    elif torch_dtype is torch.uint8:
        x = torch.randint(0, 20, shape, device='cuda', dtype=torch_dtype)
    else:
        x = torch.randint(-20, 20, shape, device='cuda', dtype=torch_dtype)
    z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)

    kernel = patch_kernel(reduce2d_kernel, {'OP': op})
    kernel[(1,)](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])

    if op == 'sum':
        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
    elif op == 'min':
        golden_z = torch.min(x, dim=axis, keepdim=False).values.to(reduced_dtype)
    elif op == 'max':
        golden_z = torch.max(x, dim=axis, keepdim=False).values.to(reduced_dtype)
    elif op == 'argmin':
        golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
    elif op == 'argmax':
        golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
    elif op == 'xor_sum':
        # PyTorch has no xor reduction, so compute the reference with NumPy.
        xor_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
        golden_z = torch.tensor(xor_npy, dtype=reduced_dtype, device=x.device)
    else:
        raise RuntimeError(f'Unknown reduce op {op}')

    if torch_dtype.is_floating_point and op == 'sum':
        # Float sums are order-sensitive; allow more error as the reduced
        # dimension grows.
        reduced_len = shape[axis]
        if reduced_len >= 256:
            rtol, atol = 0.05, 0.1
        elif reduced_len >= 32:
            rtol, atol = 0.05, 0.02
        else:
            rtol, atol = 0.01, 0.01
    else:
        rtol, atol = 0.001, 0.001
    assert_close(z, golden_z, rtol=rtol, atol=atol)
|