import pytest
import torch
from torch.testing import assert_close

import triton
import triton.language as tl
from tests.libdevice_testutil import system_libdevice_path
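
# These tests exercise Triton's tl.libdevice bindings to CUDA's libdevice
# intrinsics (sin, fmin, fma_rn, ffs) by running small elementwise kernels
# and comparing their output against PyTorch references.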


@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_sin_no_mask(num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        # Walk the block in chunks of iter_size elements, advancing the raw
        # pointers after each chunk.
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            x = tl.load(x_ptrs)
            y = tl.libdevice.sin(x)
            y_ptrs = y_ptr + offset
            tl.store(y_ptrs, y)

            x_ptr += iter_size
            y_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    # block_size == x.numel(), so the grid is always a single program.
    grid = lambda EA: (x.shape.numel() // (block_size),)
    kernel[grid](x_ptr=x, y_ptr=y,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_y = torch.sin(x)
    assert_close(y, golden_y, rtol=1e-7, atol=1e-7)
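
# The "_no_mask" suffix: every tl.load/tl.store above is unguarded, which is
# valid because block_size is always a multiple of iter_size. A guarded
# variant (hypothetical sketch, not used by these tests) would look like:
#
#     mask = i + tl.arange(0, iter_size) < block_size
#     x = tl.load(x_ptrs, mask=mask)
#     tl.store(y_ptrs, y, mask=mask)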


@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_fmin_no_mask(num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset

            x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
            # Elementwise minimum; for floats this maps to libdevice's fmin.
            z = tl.libdevice.min(x, y)
            z_ptrs = z_ptr + offset
            tl.store(z_ptrs, z)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    grid = lambda EA: (x.shape.numel() // (block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_z = torch.minimum(x, y)
    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
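
# torch.minimum propagates NaNs, while IEEE-754 fmin returns the non-NaN
# operand when exactly one input is NaN; torch.randn never produces NaNs,
# so the comparison above is unaffected by that difference.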


@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               w_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset
            z_ptrs = z_ptr + offset

            x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
            z = tl.load(z_ptrs)

            # Fused multiply-add, round-to-nearest-even: w = x * y + z.
            w = tl.libdevice.fma_rn(x, y, z)
            w_ptrs = w_ptr + offset
            tl.store(w_ptrs, w)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size
            w_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float64)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float64)
    z = torch.randn((block_size,), device='cuda', dtype=torch.float64)
    w = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    grid = lambda EA: (x.shape.numel() // (block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, w_ptr=w,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_w = x * y + z
    assert_close(w, golden_w, rtol=1e-7, atol=1e-7)
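
# fma_rn fuses the multiply and add with a single rounding step, whereas the
# reference x * y + z rounds twice; for these inputs the difference is on the
# order of one float64 ulp, well below the tolerances above.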


@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('int32', 'libdevice.ffs', system_libdevice_path()),
                          ('int32', 'libdevice.ffs', '')])
def test_libdevice(dtype_str, expr, lib_path):
    src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.{expr}(x)
    tl.store(Y + tl.arange(0, BLOCK), y)
"""
    import tempfile
    from inspect import Parameter, Signature

    import _testcapi

    # triton.jit inspects the source of the function it wraps, so write the
    # generated kernel to a real file first.
    fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
    fp.write(src)
    fp.flush()

    def kernel(X, Y, BLOCK: tl.constexpr):
        pass
    # Point the stand-in's code object at the temp file (via CPython's
    # _testcapi helper) so triton.jit picks up the generated source, and
    # attach a matching signature.
    kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
    parameters = [
        Parameter("X", Parameter.POSITIONAL_OR_KEYWORD),
        Parameter("Y", Parameter.POSITIONAL_OR_KEYWORD),
        Parameter("BLOCK", Parameter.POSITIONAL_OR_KEYWORD),
    ]
    kernel.__signature__ = Signature(parameters=parameters)
    kernel = triton.jit(kernel)

    torch_type = {
        "int32": torch.int32,
        "float32": torch.float32,
        "float64": torch.float64,
    }

    shape = (128,)
    # keep integer inputs non-negative and within the int32 range
    if dtype_str == "int32":
        x = torch.randint(2**31 - 1, shape, dtype=torch_type[dtype_str], device="cuda")
    else:
        x = torch.randn(shape, dtype=torch_type[dtype_str], device="cuda")

    if expr == 'libdevice.ffs':
        # ffs ("find first set") returns the 1-based index of the least
        # significant set bit, or 0 when the input is 0.
        y_ref = torch.zeros(shape, dtype=x.dtype, device="cuda")
        for i in range(shape[0]):
            y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
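
    # Plain-Python model of the reference above (hypothetical helper, shown
    # for clarity only): v & -v isolates the lowest set bit, and bit_length()
    # gives its 1-based position (0 when v == 0).
    #
    #     def ffs(v: int) -> int:
    #         return (v & -v).bit_length()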

    # triton result
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
    # compare
    assert_close(y, y_ref)