Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the Triton frontend, replacing Triton-C with a pure Python API in which kernels are defined as @triton.jit-decorated functions. The documentation and tutorials have been updated to reflect these changes; see the documentation for more information on the new API.
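Under the new API, a kernel is an ordinary Python function decorated with @triton.jit and launched by indexing it with a launch grid. The sketch below mirrors the pattern used by the tests added in this commit; the copy_kernel name, tensor sizes, and launch parameters are illustrative only and not part of the change itself:

import torch
import triton

@triton.jit
def copy_kernel(Z, X, **meta):
    # each program instance handles meta['SIZE'] contiguous elements
    off = triton.arange(0, meta['SIZE'])
    x = triton.load(X + off)
    triton.store(Z + off, x)

x = torch.randn(128, device='cuda')
z = torch.empty_like(x)
# launch on a one-block grid; SIZE is passed as a meta-parameter
copy_kernel[(1, )](z, x, SIZE=128, num_warps=4)

Meta-parameters such as SIZE are supplied at launch time and reach the kernel through **meta, which is how the tests below template their kernels.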
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
python/test/test_code_gen.py (new file, 209 lines)
@@ -0,0 +1,209 @@
import torch
import triton
import copy
import pytest
import ast

torch.manual_seed(0)

# convert from string to torch.dtype
# necessary because torch.dtype values don't print properly
cvt = {
    'bool': torch.bool,
    'int8': torch.int8,
    'int16': torch.int16,
    'int32': torch.int32,
    'int64': torch.int64,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
}

int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes


def patch_kernel(template, to_replace):
    kernel = copy.deepcopy(template)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel


# generic test functions
def _test_unary(dtype_x, expr, device='cuda'):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
    def kernel(Z, X, **meta):
        off = triton.arange(0, meta['SIZE'])
        x = triton.load(X + off)
        z = GENERATE_TEST_HERE
        triton.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    # reference result
    z_ref = eval(expr)
    # triton result
    z_tri = torch.empty_like(z_ref)
    kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
    # compare
    triton.testing.assert_allclose(z_ref, z_tri)


def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
    def kernel(Z, X, Y, **meta):
        off = triton.arange(0, meta['SIZE'])
        x = triton.load(X + off)
        y = triton.load(Y + off)
        z = GENERATE_TEST_HERE
        triton.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
    y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
    # reference result
    z_ref = eval(expr)
    # triton result
    z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device)
    kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4)
    # compare
    triton.testing.assert_allclose(z_ref, z_tri)


# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
    for op in ['+', '-', '*', '/', '%'] \
    for dtype_x in dtypes \
    for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, expr, device='cuda'):
    _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
    for op in ['&', '|', '^'] \
    for dtype_x in dtypes \
    for dtype_y in dtypes
])
def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
    if 'float' in dtype_x + dtype_y:
        with pytest.raises(RuntimeError):
            _test_binary(dtype_x, dtype_y, expr, device=device)
    else:
        _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test compare ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
    (dtype_x, dtype_y, f' x {op} y') \
    for op in ['==', '!=', '>', '<', '>=', '<='] \
    for dtype_x in dtypes \
    for dtype_y in dtypes
])
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
    _test_binary(dtype_x, dtype_y, expr, device=device)


# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
    (dtype_x, f' -x') for dtype_x in float_dtypes
] + [\
    (dtype_x, f' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, expr, device=device)


# ----------------
# test indexing
# ----------------


def make_ptr_str(name, shape):
    rank = len(shape)
    offsets = []
    stride = 1
    for i in reversed(range(rank)):
        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
        offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"


@pytest.mark.parametrize("expr", [f'x[{s}]' for s in
    ['None, :', ':, None',\
     'None, :, :', ':, :, None']\
])
def test_index1d(expr, device='cuda'):
    dtype = torch.int32
    rank_x = expr.count(':')
    rank_y = expr.count(',') + 1
    shape_x = [32 for _ in range(rank_x)]
    shape_z = [32 for _ in range(rank_y)]

    # Triton kernel
    @triton.jit
    def kernel(Z, X, **meta):
        SIZE = meta['SIZE']
        m = triton.arange(0, SIZE)
        n = triton.arange(0, SIZE)
        x = triton.load(X_PTR_EXPR)
        z = GENERATE_TEST_HERE
        triton.store(Z_PTR_EXPR, z)

    to_replace = {
        'X_PTR_EXPR': make_ptr_str('X', shape_x),
        'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
        'GENERATE_TEST_HERE': expr,
    }
    kernel = patch_kernel(kernel, to_replace)

    # torch result
    x = triton.testing.random(shape_x, dtype=dtype, device=device)
    y = torch.zeros(shape_z, dtype=dtype, device=device)
    z_ref = eval(expr) + y
    # triton result
    z_tri = torch.empty_like(z_ref)
    kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0])
    # compare
    triton.testing.assert_allclose(z_ref, z_tri)


# ---------------
# test load
# ---------------

# ---------------
# test store
# ---------------

# ---------------
# test if
# ---------------

# ---------------
# test for
# ---------------

# ---------------
# test while
# ---------------
@@ -1,17 +0,0 @@
import torch
import triton


def test_op():
    torch.manual_seed(0)
    DTYPE = torch.float16
    N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
    pad, stride, = (1, 1), (1, 1)
    dilation = (1, 1)
    a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
    b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
    th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
    tt_c = triton.ops.conv(a, b, pad, stride)
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)
@@ -3,66 +3,74 @@ import itertools
 import triton
 import torch


 @pytest.mark.parametrize(
-    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
-    itertools.chain(*[
-        [
-            # 1 warp
-            (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
-            (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
-            (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
-            (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
-            (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
-            (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
-            (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-            (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-            (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-            # # 2 warp
-            (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
-            (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
-            (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
-            (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
-            (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
-            (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
-            # # 4 warp
-            (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
-            (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
-            (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
-            (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
-            (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
-            (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
-            # 8 warp
-            # (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
-            # (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
-            # (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
-            # split-k
-            (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
-            (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
-            (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
-            # variable input
-            (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
-            (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
-            (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
-            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
-        ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
-    ]),
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
+    itertools.chain(
+        *[
+            [
+                # 1 warp
+                (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
+                (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
+                (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
+                (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
+                (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
+                (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
+                (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
+                (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
+                (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
+                # 2 warp
+                (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
+                (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
+                (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
+                (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
+                # 4 warp
+                (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
+                (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
+                (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
+                (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
+                (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
+                (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
+                # 8 warp
+                (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
+                (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
+                (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
+                # # split-k
+                (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
+                (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
+                (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
+                # # variable input
+                (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
+                (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
+            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
+        ]
+    ),
 )
-def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
-    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
     torch.manual_seed(0)
-    defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
-    triton.ops._matmul._kernels = dict()
-    triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
-    if M is None:
-        M = TM
-    if N is None:
-        N = TN
-    if K is None:
-        K = TK * SPLITK
+    # nuke kernel decorators -- will set meta-parameters manually
+    META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
+    configs = [triton.Config(meta=META, num_warps=NWARP)]
+    kernel = triton.ops._matmul.kernel
+    decorators = kernel.kernel_decorators
+    kernel.kernel_decorators = []
+    triton.autotune(configs, [])(kernel)
+    kernel.kernel_decorators += decorators[1:]
+    # get matrix shape
+    M = BLOCK_M if M is None else M
+    N = BLOCK_N if N is None else N
+    K = BLOCK_K * SPLIT_K if K is None else K
+    # allocate/transpose inputs
+    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
     b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
     a = a.t() if AT else a
     b = b.t() if BT else b
     # run test
     th_c = torch.matmul(a, b)
     tt_c = triton.ops.matmul(a, b)
     assert triton.testing.allclose(th_c, tt_c)