Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
Author: Philippe Tillet
Date: 2022-12-21 01:30:50 -08:00
Committed by: GitHub
Parent: 8650b4d1cb
Commit: 20100a7254
285 changed files with 26,312 additions and 50,143 deletions


@@ -0,0 +1,56 @@
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
torch_type = {
'bool': torch.bool,
'int8': torch.int8,
'uint8': torch.uint8,
'int16': torch.int16,
'int32': torch.int32,
'int64': torch.long,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
'float32': torch.float32,
'float64': torch.float64
}
def get_tensor(shape, data_type, b_positive=False):
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
return x
# @pytest.mark.parametrize('data_type',
# [("int8"),
# ('int16'),
# ('int32'),
# ("int64"),
# ('float16'),
# ("float32"),
# ("float64")])
def printf(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.printf("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# values 0..BLOCK-1, which fit in every tested dtype (including int8)
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
printf("float16")
printf("int8")


@@ -1,5 +1,6 @@
# flake8: noqa: F821,F841
import itertools
import os
import re
from typing import Optional, Union
@@ -104,8 +105,8 @@ def check_type_supported(dtype):
'''
skip test if dtype is not supported on the current device
'''
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
cc = torch.cuda.get_device_capability()
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@@ -414,8 +415,8 @@ def test_where(dtype):
def test_where_broadcast():
@triton.jit
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
mask = tl.load(cond_ptr + yoffsets)
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
@@ -424,8 +425,8 @@ def test_where_broadcast():
@triton.jit
def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
mask = 0
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
@@ -462,9 +463,6 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# ----------------
# test math ops
# ----------------
# @pytest.mark.parametrize("expr", [
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
@@ -490,9 +488,13 @@ def make_ptr_str(name, shape):
return f"{name} + {' + '.join(offsets)}"
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for s in ['None, :', ':, None']
# FIXME: 3d indexing doesn't work
#'None, :, :',
# ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
@@ -605,8 +607,8 @@ def test_tuples():
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
if dtype_x_str == 'float16':
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
n_programs = 5
@@ -651,9 +653,10 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@pytest.mark.parametrize("axis", [0, 1])
def test_tensor_atomic_rmw(axis, device="cuda"):
shape0, shape1 = 8, 8
@pytest.mark.parametrize("shape, axis",
[(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
shape0, shape1 = shape
# triton kernel
@triton.jit
@@ -662,14 +665,18 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
z = tl.sum(x, axis=AXIS)
tl.atomic_add(Z + off0, z)
if AXIS == 1:
tl.atomic_add(Z + off0, z)
else:
tl.atomic_add(Z + off1, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# reference result
z_ref = np.sum(x, axis=axis)
z_ref = np.sum(x, axis=axis, keepdims=False)
# triton result
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
z_shape = (shape0, ) if axis == 1 else (shape1, )
z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
@@ -724,6 +731,10 @@ def test_atomic_cas():
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# bfloat16 on cc < 80 will not be tested
check_type_supported(dtype_x)
check_type_supported(dtype_z)
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
if dtype_x in float_dtypes and dtype_z == 'int1':
@@ -737,9 +748,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
x_ptr = X + tl.arange(0, 1)
z_ptr = Z + tl.arange(0, 1)
x = tl.load(x_ptr)
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z)
tl.store(z_ptr, z)
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
# triton result
@@ -869,9 +882,19 @@ def test_f16_to_f8_rounding():
# ---------------
def get_reduced_dtype(dtype_str, op):
if op == 'argmin' or op == 'argmax':
return 'int32'
if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
return 'int32'
if dtype_str == 'bfloat16':
return 'float32'
return dtype_str
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for op in ['min', 'max', 'sum']
for dtype in dtypes_with_bfloat16
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -892,7 +915,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
# numpy result
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_dtype_str = get_reduced_dtype(dtype_str, op)
z_tri_dtype_str = z_dtype_str
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
@@ -919,21 +942,35 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
np.testing.assert_equal(z_ref, z_tri)
# TODO: [Qingyi] Fix argmin / argmax
reduce_configs1 = [
(op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for op in ['min', 'max', 'sum']
for axis in [1]
]
# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
# exceeds the limit of 99KB
reduce2d_shapes = [(2, 32), (4, 32), (4, 128)]
# TODO: fix and uncomment
# , (32, 64), (64, 128)]
if 'V100' in torch.cuda.get_device_name(0):
reduce2d_shapes += [(128, 256), (32, 1024)]
reduce_configs2 = [
(op, 'float32', shape, axis)
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
for op in ['min', 'max', 'sum']
for shape in reduce2d_shapes
for axis in [0, 1]
]
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -954,7 +991,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
x_tri = to_triton(x)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_dtype_str = get_reduced_dtype(dtype_str, op)
z_tri_dtype_str = z_dtype_str
# numpy result
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
@@ -992,7 +1029,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm)
for dtype in ['bfloat16', 'float16', 'float32']
# TODO: bfloat16
for dtype in ['float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -1038,25 +1076,37 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# ---------------
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
for dtype in ['float16', 'float32']
if not (allow_tf32 and (dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
for shape_nw in [[128, 256, 32, 8],
[128, 16, 32, 4],
[32, 128, 64, 4],
[128, 128, 64, 4],
[64, 128, 128, 4],
[32, 128, 64, 2],
[128, 128, 64, 2],
[64, 128, 128, 4]]
for allow_tf32 in [True]
for col_a in [True, False]
for col_b in [True, False]
for dtype in ['int8', 'float16', 'float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80:
if capability[0] < 8:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
M, N, K = 128, 128, 64
num_warps = 8
trans_a, trans_b = False, False
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
# triton kernel
@triton.jit
@@ -1068,7 +1118,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
COL_A: tl.constexpr, COL_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
@@ -1077,7 +1127,9 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
x = tl.load(Xs)
y = tl.load(Ys)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
@@ -1093,16 +1145,24 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w)
tl.store(Zs, z)
# input
rs = RandomState(17)
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
if allow_tf32:
if col_a:
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
else:
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
if col_b:
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
else:
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
if 'int' not in dtype:
x *= .1
y *= .1
if dtype == 'float32' and allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
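The masking above emulates TF32 rounding: clearing the low 13 mantissa bits of a float32 leaves the 10 mantissa bits that TF32 tensor cores actually keep, so the numpy reference matches the hardware result. A small self-contained sketch of the same trick (illustrative, not part of the diff):

# illustrative only: truncate float32 values to TF32 precision (10 mantissa bits)
import numpy as np

x = np.array([1.0 + 1e-5, 3.14159265], dtype=np.float32)
x_tf32 = (x.view(np.uint32) & np.uint32(0xffffe000)).view(np.float32)
print(x - x_tf32)  # small residuals: the discarded low-order mantissa bits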
@@ -1110,7 +1170,11 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
if dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
else:
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -1118,7 +1182,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
TRANS_A=trans_a, TRANS_B=trans_b,
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
@@ -1128,9 +1192,12 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
x_ref = x.T if trans_a else x
y_ref = y.T if trans_b else y
z_ref = np.matmul(x_ref, y_ref)
if dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
y.astype(np.float32)).astype(np.int32)
else:
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
@@ -1142,35 +1209,39 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
if dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
elif dtype == 'float32' and not allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
def test_dot_without_load():
@triton.jit
def kernel(out):
pid = tl.program_id(axis=0)
a = tl.zeros((32, 32), tl.float32)
b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)
out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
# FIXME: Unsupported layout found in ConvertSplatLikeOp
# def test_dot_without_load():
# @triton.jit
# def kernel(out):
# pid = tl.program_id(axis=0)
# a = tl.zeros((32, 32), tl.float32)
# b = tl.zeros((32, 32), tl.float32)
# c = tl.zeros((32, 32), tl.float32)
# c = tl.dot(a, b)
# pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(pout, c)
#
# out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
# kernel[(1,)](out)
# ---------------
# test arange
@@ -1216,7 +1287,7 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
@@ -1227,16 +1298,12 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermediate copy to shared memory.
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
M = 32
@@ -1325,6 +1392,7 @@ def test_vectorization(N):
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
# ---------------
@@ -1402,6 +1470,10 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
JITFunction.cache_hook = None
assert spec_type == value_type
# --------------------
# value specialization
# --------------------
@pytest.mark.parametrize(
"value, overflow",
@@ -1552,9 +1624,23 @@ def test_num_warps_pow2():
# -------------
def system_libdevice_path() -> str:
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
'/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
if os.path.exists(_p):
SYSTEM_LIBDEVICE_PATH = _p
assert SYSTEM_LIBDEVICE_PATH is not None, \
"Could not find libdevice.10.bc path"
return SYSTEM_LIBDEVICE_PATH
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float32', 'libdevice.pow', system_libdevice_path()),
('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):
@@ -1621,3 +1707,95 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
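The `system_libdevice_path()` helper above only probes two well-known CUDA install locations; a standalone sketch of the same probe, using the same candidate paths the test assumes:

# illustrative only: probe the candidate libdevice locations used by the test helper
import os

candidates = [
    '/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
    '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]
found = next((p for p in candidates if os.path.exists(p)), None)
print(found or "libdevice.10.bc not found in the default locations")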
# -----------------------
# test layout conversions
# -----------------------
# TODO: backend should be tested separately
class MmaLayout:
def __init__(self, version, warps_per_cta):
self.version = version
self.warps_per_cta = str(warps_per_cta)
def __str__(self):
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
class BlockedLayout:
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
self.sz_per_thread = str(size_per_thread)
self.threads_per_warp = str(threads_per_warp)
self.warps_per_cta = str(warps_per_cta)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}
}
"""
x = to_triton(numpy_random(shape, dtype_str=dtype))
z = torch.empty_like(x)
# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
assert torch.equal(z, x)
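The `MmaLayout` / `BlockedLayout` helpers above simply render `#triton_gpu` attribute strings that get spliced into the TTGIR template; an illustrative check of the rendering (assuming the class definitions above):

# illustrative only: the layout helpers just format triton_gpu attribute strings
layout = BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0])
print(layout)
# #triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[4, 8], warpsPerCTA=[2, 2], order=[1, 0]}>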


@@ -1,261 +0,0 @@
# flake8: noqa: F821,F841
import random
import torch
import triton
import triton.language as tl
@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
mask = w_offsets < (size // 4)
input_ptrs = input_ptr + 1 + w_offsets
input = tl.load(input_ptrs, mask=mask, other=0)
scale_shift = tl.load(input_ptr)
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 8)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_scale_shift_int8(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
mask = w_offsets < (size // 4)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 8)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = input_ptr + 1 + w_offsets
input = tl.load(input_ptrs, mask=mask, other=0)
scale_shift = tl.load(input_ptr)
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 4)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_scale_shift_int4(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 4)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 2)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_scale_shift_int2(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 2)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
def test_dequantize_int8() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 128, 4)
else:
size = random.randrange(132, 1024, 4)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int32)
input_int8 = torch.randint(
0, 256, (size,), dtype=torch.uint8, device=device
)
input_int32 = input_int8.view(torch.int32)
input = torch.cat((scale_shift, input_int32))
expected = (input_int8 * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 128)
grid = (1,)
dequantize_kernel_int8[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int8[grid](
output,
input_int32,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)
def test_dequantize_int4() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 256, 8)
else:
size = random.randrange(264, 1024, 8)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int32)
input_int8 = torch.randint(
0, 256, (size // 2,), dtype=torch.uint8, device=device
)
input_int32 = input_int8.view(torch.int32)
input_int8_h1 = input_int8 >> 4
input_int8_h0 = input_int8 & 15
input_int4_val = torch.stack(
(input_int8_h0, input_int8_h1), dim=1
).flatten()
input = torch.cat((scale_shift, input_int32))
expected = (input_int4_val * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 256)
grid = (1,)
dequantize_kernel_int4[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int4[grid](
output,
input_int32,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)
def test_dequantize_int2() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 256, 8)
else:
size = random.randrange(264, 1024, 8)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int16)
input_int8 = torch.randint(
0, 256, (size // 4,), dtype=torch.uint8, device=device
)
input_int16 = input_int8.view(torch.int16)
input_int8_q3 = input_int8 >> 6
input_int8_q2 = (input_int8 >> 4) & 3
input_int8_q1 = (input_int8 >> 2) & 3
input_int8_q0 = input_int8 & 3
input_int2_val = torch.stack(
(input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
).flatten()
input = torch.cat((scale_shift, input_int16))
expected = (input_int2_val * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 256)
grid = (1,)
dequantize_kernel_int2[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int2[grid](
output,
input_int16,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)
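For context on the deleted tests above: the int8 and int4 variants pack a float16 scale and a float16 shift into a single 32-bit header word at the front of the input buffer, and the kernels recover them with a mask, a shift, and a bitcast. A small illustrative round trip in plain PyTorch (no Triton required; assumes a little-endian host, as the deleted test did implicitly):

# illustrative only: pack/unpack the (scale, shift) float16 pair as the deleted kernels did
import torch

scale, shift = 0.5, -2.0
header = torch.tensor([scale, shift], dtype=torch.float16).view(torch.int32)  # one int32 word
unpacked_scale = (header & 65535).to(torch.int16).view(torch.float16)  # low 16 bits
unpacked_shift = (header >> 16).to(torch.int16).view(torch.float16)    # high 16 bits
assert unpacked_scale.item() == scale and unpacked_shift.item() == shift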


@@ -0,0 +1,22 @@
import os
import subprocess
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128


@@ -2,13 +2,13 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
# TODO: float32 fails
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
seed = 0
@@ -32,9 +32,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
@@ -126,8 +126,8 @@ def test_attention_fwd_bwd(
batch_size=2,
n_heads=2,
):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
# inputs


@@ -2,20 +2,19 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['bfloat16', 'float16', 'float32']
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and dtype == "bfloat16":
capability = torch.cuda.get_device_capability()
if capability[0] < 8 and dtype == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
# create inputs


@@ -4,7 +4,6 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize(
@@ -67,10 +66,10 @@ import triton._C.libtriton.triton as _triton
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80 and DTYPE == "bfloat16":
if capability[0] < 8 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
if DTYPE == "bfloat16" and SPLIT_K != 1:
pytest.skip("bfloat16 matmuls don't allow split_k for now")