Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
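As a quick illustration of the backward-compatibility goal stated above, user-facing kernels written against the existing Python API are expected to keep compiling and running unchanged on the new MLIR-based backend. The snippet below is a minimal sketch of such usage; the kernel and tensor names are illustrative only and are not part of this PR:

import torch

import triton
import triton.language as tl


@triton.jit
def add_kernel(X, Y, Z, BLOCK: tl.constexpr):
    # element-wise add over a single block of BLOCK contiguous elements
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    y = tl.load(Y + offsets)
    tl.store(Z + offsets, x + y)


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
z = torch.empty_like(x)
add_kernel[(1,)](x, y, z, BLOCK=1024)
assert torch.allclose(z, x + y)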
56 python/test/unit/language/printf_helper.py (new file)
@@ -0,0 +1,56 @@
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

torch_type = {
    "bool": torch.bool,
    'int8': torch.int8,
    'uint8': torch.uint8,
    'int16': torch.int16,
    "int32": torch.int32,
    'int64': torch.long,
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64
}


def get_tensor(shape, data_type, b_positive=False):
    x = None
    if data_type.startswith('int'):
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
    else:
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')

    return x

# @pytest.mark.parametrize('data_type',
#                          [("int8"),
#                           ('int16'),
#                           ('int32'),
#                           ("int64"),
#                           ('float16'),
#                           ("float32"),
#                           ("float64")])


def printf(data_type):
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.printf("", x)
        tl.store(Y + tl.arange(0, BLOCK), x)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = get_tensor(shape, data_type)
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


printf("float16")
printf("int8")
@@ -1,5 +1,6 @@
# flake8: noqa: F821,F841
import itertools
import os
import re
from typing import Optional, Union

@@ -104,8 +105,8 @@ def check_type_supported(dtype):
    '''
    skip test if dtype is not supported on the current device
    '''
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
    cc = torch.cuda.get_device_capability()
    if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")


@@ -414,8 +415,8 @@ def test_where(dtype):
def test_where_broadcast():
    @triton.jit
    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
        yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]

        mask = tl.load(cond_ptr + yoffsets)
        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
@@ -424,8 +425,8 @@ def test_where_broadcast():

    @triton.jit
    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
        yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
        mask = 0
        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
        res = tl.where(mask, vals, 0.)
@@ -462,9 +463,6 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# ----------------
# test math ops
# ----------------
# @pytest.mark.parametrize("expr", [
#     'exp', 'log', 'cos', 'sin'
# ])


@pytest.mark.parametrize("expr", [
@@ -490,9 +488,13 @@ def make_ptr_str(name, shape):
    return f"{name} + {' + '.join(offsets)}"


# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
@pytest.mark.parametrize("expr, dtype_str", [
    (f'x[{s}]', d)
    for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
    for s in ['None, :', ':, None']
    # FIXME: 3d indexing doesn't work
    #'None, :, :',
    # ':, :, None']
    for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
@@ -605,8 +607,8 @@ def test_tuples():
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        if dtype_x_str == 'float16':
            pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
    n_programs = 5
@@ -651,9 +653,10 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


@pytest.mark.parametrize("axis", [0, 1])
def test_tensor_atomic_rmw(axis, device="cuda"):
    shape0, shape1 = 8, 8
@pytest.mark.parametrize("shape, axis",
                         [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
    shape0, shape1 = shape
    # triton kernel

    @triton.jit
@@ -662,14 +665,18 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
        off1 = tl.arange(0, SHAPE1)
        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
        z = tl.sum(x, axis=AXIS)
        tl.atomic_add(Z + off0, z)
        if AXIS == 1:
            tl.atomic_add(Z + off0, z)
        else:
            tl.atomic_add(Z + off1, z)
    rs = RandomState(17)
    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
    # reference result
    z_ref = np.sum(x, axis=axis)
    z_ref = np.sum(x, axis=axis, keepdims=False)
    # triton result
    x_tri = to_triton(x, device=device)
    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
    z_shape = (shape0, ) if axis == 1 else (shape1, )
    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

@@ -724,6 +731,10 @@ def test_atomic_cas():
    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
    # bfloat16 on cc < 80 will not be tested
    check_type_supported(dtype_x)
    check_type_supported(dtype_z)

    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
    x0 = 43 if dtype_x in int_dtypes else 43.5
    if dtype_x in float_dtypes and dtype_z == 'int1':
@@ -737,9 +748,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
    # triton kernel
    @triton.jit
    def kernel(X, Z, BITCAST: tl.constexpr):
        x = tl.load(X)
        x_ptr = X + tl.arange(0, 1)
        z_ptr = Z + tl.arange(0, 1)
        x = tl.load(x_ptr)
        z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
        tl.store(Z, z)
        tl.store(z_ptr, z)

    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
    # triton result
@@ -869,9 +882,19 @@ def test_f16_to_f8_rounding():
# ---------------


def get_reduced_dtype(dtype_str, op):
    if op == 'argmin' or op == 'argmax':
        return 'int32'
    if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
        return 'int32'
    if dtype_str == 'bfloat16':
        return 'float32'
    return dtype_str


@pytest.mark.parametrize("op, dtype_str, shape",
                         [(op, dtype, shape)
                          for op in ['min', 'max', 'argmin', 'argmax', 'sum']
                          for op in ['min', 'max', 'sum']
                          for dtype in dtypes_with_bfloat16
                          for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -892,7 +915,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                'argmin': np.argmin, 'argmax': np.argmax}[op]
    # numpy result
    z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
    z_dtype_str = get_reduced_dtype(dtype_str, op)
    z_tri_dtype_str = z_dtype_str
    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
        z_dtype_str = 'float32'
@@ -919,21 +942,35 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
    np.testing.assert_equal(z_ref, z_tri)


# TODO: [Qingyi] Fix argmin / argmax
reduce_configs1 = [
    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
    for op in ['min', 'max', 'argmin', 'argmax', 'sum']
    for op in ['min', 'max', 'sum']
    for axis in [1]
]


# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
# exceeds the limit of 99KB
reduce2d_shapes = [(2, 32), (4, 32), (4, 128)]
# TODO: fix and uncomment
# , (32, 64), (64, 128)]
if 'V100' in torch.cuda.get_device_name(0):
    reduce2d_shapes += [(128, 256) and (32, 1024)]


reduce_configs2 = [
    (op, 'float32', shape, axis)
    for op in ['min', 'max', 'argmin', 'argmax', 'sum']
    for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
    for op in ['min', 'max', 'sum']
    for shape in reduce2d_shapes
    for axis in [0, 1]
]


@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -954,7 +991,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
    x_tri = to_triton(x)
    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                'argmin': np.argmin, 'argmax': np.argmax}[op]
    z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
    z_dtype_str = get_reduced_dtype(dtype_str, op)
    z_tri_dtype_str = z_dtype_str
    # numpy result
    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
@@ -992,7 +1029,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):

@pytest.mark.parametrize("dtype_str, shape, perm",
                         [(dtype, shape, perm)
                          for dtype in ['bfloat16', 'float16', 'float32']
                          # TODO: bfloat16
                          for dtype in ['float16', 'float32']
                          for shape in [(64, 64), (128, 128)]
                          for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -1038,25 +1076,37 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# ---------------


@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
                         [(epilogue, allow_tf32, dtype)
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
                         [(*shape, 4, False, False, epilogue, allow_tf32, dtype)
                          for shape in [(64, 64, 64)]
                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                          for allow_tf32 in [True, False]
                          for dtype in ['float16']
                          if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
                          for dtype in ['float16', 'float32']
                          if not (allow_tf32 and (dtype in ['float16']))] +

                         [(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
                          for shape_nw in [[128, 256, 32, 8],
                                           [128, 16, 32, 4],
                                           [32, 128, 64, 4],
                                           [128, 128, 64, 4],
                                           [64, 128, 128, 4],
                                           [32, 128, 64, 2],
                                           [128, 128, 64, 2],
                                           [64, 128, 128, 4]]
                          for allow_tf32 in [True]
                          for col_a in [True, False]
                          for col_b in [True, False]
                          for dtype in ['int8', 'float16', 'float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if cc < 80:
    if capability[0] < 8:
        if dtype == 'int8':
            pytest.skip("Only test int8 on devices with sm >= 80")
        elif dtype == 'float32' and allow_tf32:
            pytest.skip("Only test tf32 on devices with sm >= 80")

    M, N, K = 128, 128, 64
    num_warps = 8
    trans_a, trans_b = False, False
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32

    # triton kernel
    @triton.jit
@@ -1068,7 +1118,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
               ALLOW_TF32: tl.constexpr,
               DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
               TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
               COL_A: tl.constexpr, COL_B: tl.constexpr):
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_l = tl.arange(0, BLOCK_N)
@@ -1077,7 +1127,9 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
        Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
        x = tl.load(Xs)
        y = tl.load(Ys)
        z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
        if ADD_MATRIX:
            z += tl.load(Zs)
        if ADD_ROWS:
@@ -1093,16 +1145,24 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
            den = tl.sum(num, 1)
            z = num / den[:, None]
        if CHAIN_DOT:
            # tl.store(Zs, z)
            # tl.debug_barrier()
            z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
            w = tl.load(Ws)
            z = tl.dot(z.to(w.dtype), w)
        tl.store(Zs, z)
    # input
    rs = RandomState(17)
    x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
    y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
    w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
    if allow_tf32:
    if col_a:
        x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
    else:
        x = numpy_random((M, K), dtype_str=dtype, rs=rs)
    if col_b:
        y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
    else:
        y = numpy_random((K, N), dtype_str=dtype, rs=rs)
    w = numpy_random((N, N), dtype_str=dtype, rs=rs)
    if 'int' not in dtype:
        x *= .1
        y *= .1
    if dtype == 'float32' and allow_tf32:
        x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
        y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
        w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
@@ -1110,7 +1170,11 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
    y_tri = to_triton(y, device=device)
    w_tri = to_triton(w, device=device)
    # triton result
    z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
    if dtype == 'int8':
        z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
    else:
        z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1

    z_tri = to_triton(z, device=device)
    if epilogue == 'trans':
        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -1118,7 +1182,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
                         y_tri, y_tri.stride(0), y_tri.stride(1),
                         w_tri, w_tri.stride(0), w_tri.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         TRANS_A=trans_a, TRANS_B=trans_b,
                         COL_A=col_a, COL_B=col_b,
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX=epilogue == 'add-matrix',
                         ADD_ROWS=epilogue == 'add-rows',
@@ -1128,9 +1192,12 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
                         ALLOW_TF32=allow_tf32,
                         num_warps=num_warps)
    # torch result
    x_ref = x.T if trans_a else x
    y_ref = y.T if trans_b else y
    z_ref = np.matmul(x_ref, y_ref)
    if dtype == 'int8':
        z_ref = np.matmul(x.astype(np.float32),
                          y.astype(np.float32())).astype(np.int32)
    else:
        z_ref = np.matmul(x, y)

    if epilogue == 'add-matrix':
        z_ref += z
    if epilogue == 'add-rows':
@@ -1142,35 +1209,39 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
        denom = np.sum(num, axis=-1, keepdims=True)
        z_ref = num / denom
    if epilogue == 'chain-dot':
        z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
        z_ref = np.matmul(z_ref, w)
    # compare
    # print(z_ref[:,0], z_tri[:,0])
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    if dtype == 'float32':
        # XXX: Somehow there's a larger difference when we use float32
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
    # make sure ld/st are vectorized
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
    if allow_tf32:
    if dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif dtype == 'float32':
    elif dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


def test_dot_without_load():
    @triton.jit
    def kernel(out):
        pid = tl.program_id(axis=0)
        a = tl.zeros((32, 32), tl.float32)
        b = tl.zeros((32, 32), tl.float32)
        c = tl.zeros((32, 32), tl.float32)
        c = tl.dot(a, b)
        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
        tl.store(pout, c)

    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
    kernel[(1,)](out)
# FIXME: Unsupported layout found in ConvertSplatLikeOp
# def test_dot_without_load():
#     @triton.jit
#     def kernel(out):
#         pid = tl.program_id(axis=0)
#         a = tl.zeros((32, 32), tl.float32)
#         b = tl.zeros((32, 32), tl.float32)
#         c = tl.zeros((32, 32), tl.float32)
#         c = tl.dot(a, b)
#         pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
#         tl.store(pout, c)
#
#     out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
#     kernel[(1,)](out)

# ---------------
# test arange
@@ -1216,7 +1287,7 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
    def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
        in_offsets = tl.arange(0, out_size)
        # Load inputs.
        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
        # Store output
        output_offsets = tl.arange(0, out_size)
        tl.store(out_ptr + output_offsets, x)
@@ -1227,16 +1298,12 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
    triton.testing.allclose(output, reference_out)


# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermediate copy to shared memory run.


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")

    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    M = 32
@@ -1325,6 +1392,7 @@ def test_vectorization(N):
    else:
        assert "ld.global.b32" in ptx
    # triton.testing.assert_almost_equal(dst, src[:N])

# ---------------
# test store
# ---------------
@@ -1402,6 +1470,10 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
    JITFunction.cache_hook = None
    assert spec_type == value_type

# --------------------
# value specialization
# --------------------


@pytest.mark.parametrize(
    "value, overflow",
@@ -1552,9 +1624,23 @@ def test_num_warps_pow2():
# -------------


def system_libdevice_path() -> str:
    _SYSTEM_LIBDEVICE_SEARCH_PATHS = [
        '/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
        '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
    ]
    SYSTEM_LIBDEVICE_PATH: Optional[str] = None
    for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
        if os.path.exists(_p):
            SYSTEM_LIBDEVICE_PATH = _p
    assert SYSTEM_LIBDEVICE_PATH is not None, \
        "Could not find libdevice.10.bc path"
    return SYSTEM_LIBDEVICE_PATH


@pytest.mark.parametrize("dtype_str, expr, lib_path",
                         [('int32', 'libdevice.ffs', ''),
                          ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
                          ('float32', 'libdevice.pow', system_libdevice_path()),
                          ('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):

@@ -1621,3 +1707,95 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
    kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
    # compare
    np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

# -----------------------
# test layout conversions
# -----------------------
# TODO: backend should be tested separately


class MmaLayout:
    def __init__(self, version, warps_per_cta):
        self.version = version
        self.warps_per_cta = str(warps_per_cta)

    def __str__(self):
        return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"


class BlockedLayout:
    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
        self.sz_per_thread = str(size_per_thread)
        self.threads_per_warp = str(threads_per_warp)
        self.warps_per_cta = str(warps_per_cta)
        self.order = str(order)

    def __str__(self):
        return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"


layouts = [
    # MmaLayout(version=1, warps_per_cta=[1, 4]),
    MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
    # MmaLayout(version=1, warps_per_cta=[4, 1]),
    MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
    BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
    BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
    BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
    BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
    BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
    BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]


@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
    if str(src_layout) == str(dst_layout):
        pytest.skip()
    if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
        pytest.skip()

    ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<128> : tensor<128x1xi32, #src>
    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
    %2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
    %4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
    %5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
    %6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
    %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
    %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
    %9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
    %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
    %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
    %3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
    %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
    %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
    %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
    tt.store %14, %13 : tensor<128x128xf16, #dst>
    return
  }
}
"""

    x = to_triton(numpy_random(shape, dtype_str=dtype))
    z = torch.empty_like(x)

    # write the IR to a temporary file using mkstemp
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        kernel = triton.compile(f.name)
        kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())

    assert torch.equal(z, x)

@@ -1,261 +0,0 @@
# flake8: noqa: F821,F841

import random

import torch

import triton
import triton.language as tl


@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int8(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int4(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
    shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int2(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


def test_dequantize_int8() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 128, 4)
        else:
            size = random.randrange(132, 1024, 4)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int8 * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 128)
        grid = (1,)
        dequantize_kernel_int8[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int8[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int4() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size // 2,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input_int8_h1 = input_int8 >> 4
        input_int8_h0 = input_int8 & 15

        input_int4_val = torch.stack(
            (input_int8_h0, input_int8_h1), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int4_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)
        dequantize_kernel_int4[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int4[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int2() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int16)

        input_int8 = torch.randint(
            0, 256, (size // 4,), dtype=torch.uint8, device=device
        )
        input_int16 = input_int8.view(torch.int16)

        input_int8_q3 = input_int8 >> 6
        input_int8_q2 = (input_int8 >> 4) & 3
        input_int8_q1 = (input_int8 >> 2) & 3
        input_int8_q0 = input_int8 & 3

        input_int2_val = torch.stack(
            (input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int16))
        expected = (input_int2_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)

        dequantize_kernel_int2[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int2[grid](
            output,
            input_int16,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
22 python/test/unit/language/test_printf.py (new file)
@@ -0,0 +1,22 @@
import os
import subprocess
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")


def test_printf():
    proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
    (outs, err) = proc.communicate()
    outs = outs.split()
    new_lines = set()
    for line in outs:
        try:
            value = int(float(line))
            new_lines.add(value)
        except Exception as e:
            print(e)
    for i in range(128):
        assert i in new_lines
    assert len(new_lines) == 128
@@ -2,13 +2,13 @@ import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
# TODO: float32 fails
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    seed = 0
@@ -32,9 +32,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
    layout[1, 2, :] = 0
    layout[1, :, 1] = 0
    # create data
    a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
    b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
    dc_ref, dc_tri = triton.testing.make_pair(c_shape)
    a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
    b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
    dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
    # compute [torch]
    dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    a_ref = do_mask(a_ref) if is_dsd else a_ref
@@ -126,8 +126,8 @@ def test_attention_fwd_bwd(
    batch_size=2,
    n_heads=2,
):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")

    # inputs
@@ -2,20 +2,19 @@ import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize("M, N, dtype, mode",
                         [
                             (M, N, dtype, mode) for M in [1024, 821]
                             for N in [512, 857, 1871, 2089, 8573, 31000]
                             for dtype in ['bfloat16', 'float16', 'float32']
                             for dtype in ['float16', 'float32']
                             for mode in ['forward', 'backward']
                         ]
                         )
def test_op(M, N, dtype, mode):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 80 and dtype == "bfloat16":
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8 and dtype == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
@@ -4,7 +4,6 @@ import pytest
import torch

import triton
import triton._C.libtriton.triton as _triton


@pytest.mark.parametrize(
@@ -67,10 +66,10 @@ import triton._C.libtriton.triton as _triton
    ),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < 70:
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    if cc < 80 and DTYPE == "bfloat16":
    if capability[0] < 8 and DTYPE == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    if DTYPE == "bfloat16" and SPLIT_K != 1:
        pytest.skip("bfloat16 matmuls don't allow split_k for now")