[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg`, and then clean up a few warts. This PR should be a functional no-op: it is all boring whitespace and import-ordering changes, and any config-file changes will go in a separate change to make review easier.
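For reference, the ignored pycodestyle codes are E501 (line too long), E701 (multiple statements on one line after a colon), and E731 (assignment of a lambda expression). The `.isort.cfg` itself is not part of this change; the snippet below is only a minimal sketch of what such a config looks like, using real isort option names but hypothetical values rather than the project's actual settings:
```
[settings]
# Hypothetical example values; the real .isort.cfg lands in a separate change.
line_length = 88
known_first_party = triton
```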
Author: Madeleine Thompson
Date: 2022-01-06 14:34:17 -08:00
Committed by: GitHub
Parent: 120cda015e
Commit: 8bf551ae7a
30 changed files with 742 additions and 623 deletions

View File

@@ -1,4 +1,5 @@
import torch
import triton
# -------------------------------
@@ -8,18 +9,18 @@ import triton
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block',
line_vals = [16, 32, 64, 128],
line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
ylabel = 'TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\
for AT in [False] for BT in [False] \
for op_mode in ['dsd'] for layout_mode in ['dense']
x_names=['M', 'N', 'K'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64, 128],
line_names=['Block16', 'Block32', 'Block64', 'Block128'],
ylabel='TFLOPS',
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)
for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense']
]
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout
@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode]*1e-12
}[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block',
line_vals = [16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS',
plot_name = f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\
x_names=['M', 'N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64],
line_names=['Block16', 'Block32', 'Block64'],
ylabel='GBPS',
plot_name=f'{layout_mode}-square',
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)
for layout_mode in ['dense', 'tril']
]
@@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
bench_matmul.run(print_data=True, show_plots=True)
bench_matmul.run(print_data=True, show_plots=True)

View File

@@ -1,17 +1,18 @@
import torch
import triton
confs = [
triton.testing.Benchmark(
x_names = ['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg = 'provider',
line_vals = ['triton', 'torch'],
line_names = ['Triton', 'Torch'],
ylabel = 'GBPS',
plot_name = f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\
x_names=['N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
ylabel='GBPS',
plot_name=f'{mode}-2048',
args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)
for mode in ['forward', 'backward']
]
@@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
'triton': triton.ops.cross_entropy}[provider]
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':
@@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider):
if __name__ == '__main__':
bench_op.run(print_data=True)
bench_op.run(print_data=True)

View File

@@ -1,6 +1,6 @@
import triton
import torch
import os
import triton
def rounded_linspace(low, high, steps, div):
@@ -29,16 +29,16 @@ square_confs = [
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128),
x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\
for i, x in enumerate(["N", "K"])\
for M in [2048]
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]
for i, x in enumerate(["N", "K"])
for M in [2048]
]
@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t()
if BT: b = b.t()
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
except:
except Exception:
return None
return None

View File

@@ -1,7 +1,8 @@
import argparse
import sys
import os
import inspect
import os
import sys
import triton

View File

@@ -1,29 +1,28 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils
import glob
import tempfile
import shutil
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn
import urllib.request
import os
import platform
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
versions = ['-11.0', '-11', '-11-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return '', ''
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'
@@ -32,7 +31,7 @@ def get_llvm():
if not os.path.exists(llvm_library_dir):
try:
shutil.rmtree(os.path.join(dir, name))
except:
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
@@ -96,7 +95,7 @@ class CMakeBuild(build_ext):
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
]

View File

@@ -1,14 +1,18 @@
from numpy import record
import torch
import triton
import triton.language as tl
import subprocess
import sys
import pytest
import torch
from numpy import record
import triton
#######################
# Utilities
#######################
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -23,48 +27,51 @@ def nvsmi(attrs):
#######################
matmul_data = {
# square
(256 , 256 , 256 ) : {'v100': 0.027},
(512 , 512 , 512 ) : {'v100': 0.158},
(1024, 1024, 1024 ) : {'v100': 0.466},
(2048, 2048, 2048 ) : {'v100': 0.680},
(4096, 4096, 4096 ) : {'v100': 0.831},
(8192, 8192, 8192 ) : {'v100': 0.849},
# tall-skinny
(16 , 1024, 1024 ) : {'v100': 0.0128},
(16 , 4096, 4096 ) : {'v100': 0.0883},
(16 , 8192, 8192 ) : {'v100': 0.101},
(64 , 1024, 1024 ) : {'v100': 0.073},
(64 , 4096, 4096 ) : {'v100': 0.270},
(64 , 8192, 8192 ) : {'v100': 0.360},
(1024, 64 , 1024 ) : {'v100': 0.0692},
(4096, 64 , 4096 ) : {'v100': 0.264},
(8192, 64 , 8192 ) : {'v100': 0.323},
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},
# (64 , 64 , 65536) : {'v100': 0.},
# (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.},
# square
(256, 256, 256): {'v100': 0.027},
(512, 512, 512): {'v100': 0.158},
(1024, 1024, 1024): {'v100': 0.466},
(2048, 2048, 2048): {'v100': 0.680},
(4096, 4096, 4096): {'v100': 0.831},
(8192, 8192, 8192): {'v100': 0.849},
# tall-skinny
(16, 1024, 1024): {'v100': 0.0128},
(16, 4096, 4096): {'v100': 0.0883},
(16, 8192, 8192): {'v100': 0.101},
(64, 1024, 1024): {'v100': 0.073},
(64, 4096, 4096): {'v100': 0.270},
(64, 8192, 8192): {'v100': 0.360},
(1024, 64, 1024): {'v100': 0.0692},
(4096, 64, 4096): {'v100': 0.264},
(8192, 64, 8192): {'v100': 0.323},
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},
# (64 , 64 , 65536) : {'v100': 0.},
# (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.},
}
@pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K):
ref_gpu_util = matmul_data[(M, N, K)]['v100']
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = 1350
max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((K, N), dtype=torch.float16, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2.*M*N*K/ms * 1e-9
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
#######################
# Element-Wise
#######################
import triton.language as tl
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -80,21 +87,22 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = {
1024*16 : {'v100': 0.0219},
1024*64 : {'v100': 0.0791},
1024*256 : {'v100': 0.243},
1024*1024 : {'v100': 0.534},
1024*4096 : {'v100': 0.796},
1024*16384: {'v100': 0.905},
1024*65536: {'v100': 0.939},
1024 * 16: {'v100': 0.0219},
1024 * 64: {'v100': 0.0791},
1024 * 256: {'v100': 0.243},
1024 * 1024: {'v100': 0.534},
1024 * 4096: {'v100': 0.796},
1024 * 16384: {'v100': 0.905},
1024 * 65536: {'v100': 0.939},
}
@pytest.mark.parametrize('N', elementwise_data.keys())
def test_elementwise(N):
ref_gpu_util = elementwise_data[N]['v100']
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = 877
max_gpu_perf = 512*2*ref_mem_clock*1e-3
max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
@@ -102,7 +110,6 @@ def test_elementwise(N):
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
@@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
@@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
# real
[
(dtype_x, dtype_y, op, 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', op, mode_x, mode_y) \
for op in ops
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
('nan' , 'nan')]
])
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
# real
[
(dtype_x, dtype_y, op, 'real', 'real')
for op in ops
for dtype_x in dtypes
for dtype_y in dtypes
] +
# NaNs
[('float32', 'float32', op, mode_x, mode_y)
for op in ops
for mode_x, mode_y in [('nan', 'real'),
('real', 'nan'),
('nan', 'nan')]
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
@@ -368,8 +376,8 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
@triton.jit
def fn(a, b):
return a + b, \
a - b, \
a * b
a - b, \
a * b
def test_tuples():
@@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False)
for dtype_x in dtypes
for dtype_z in dtypes
for dtype_x in dtypes
for dtype_z in dtypes
] + [
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
@@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z)
# triton result
@@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape) \
for dtype in dtypes\
for shape in [128, 512]])
[(dtype, shape)
for dtype in dtypes
for shape in [128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'):
# triton kernel
@@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = tl.sum(x, axis=AXIS)
tl.store(Z + range_m, z)
# input
@@ -608,11 +618,13 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) \
for dtype in ['float32']\
for shape in [(128, 128)]\
for perm in [(1, 0)]])
[(dtype, shape, perm)
for dtype in ['float32']
for shape in [(128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
# triton kernel
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
# triton kernel
@@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX = epilogue=='add-matrix',
ADD_ROWS = epilogue=='add-rows',
ADD_COLS = epilogue=='add-cols')
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols')
# torch result
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:,0][:, None]
z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
z_ref += z[0,:][None, :]
z_ref += z[0, :][None, :]
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
@triton.jit
def kernel(out):
@@ -713,28 +727,30 @@ def test_dot_without_load():
b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)
out = torch.ones((32,32), dtype=torch.float32, device="cuda")
out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
# ---------------
# test arange
# ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr):
off = tl.arange(0, BLOCK)
val = tl.arange(START, END)
tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
# ---------------
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
M = 32
@@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
N_offsets = tl.arange(0, N)
K_offsets = tl.arange(0, K)
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
o = tl.dot(x, w)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],
in2.stride()[0],
out.stride()[0],
in1.numel(),
in2.numel(),
out.numel(),
M=M, N=N, K=K)
in1.stride()[0],
in2.stride()[0],
out.stride()[0],
in1.numel(),
in2.numel(),
out.numel(),
M=M, N=N, K=K)
reference_out =torch.matmul(in1, in2)
reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache):
src = torch.empty(128, device='cuda')
@@ -796,8 +815,8 @@ def test_load_cache_modifier(cache):
@triton.jit
def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128)
x = tl.load(src+offsets, cache_modifier=CACHE)
tl.store(dst+offsets, x)
x = tl.load(src + offsets, cache_modifier=CACHE)
tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
@@ -830,11 +849,14 @@ def test_load_cache_modifier(cache):
# ---------------
# test default
# ---------------
#TODO: can't be local to test_default
# TODO: can't be local to test_default
@triton.jit
def _impl(value = 10):
def _impl(value=10):
return value
def test_default():
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -851,7 +873,9 @@ def test_default():
# ---------------
# test noop
#----------------
# ----------------
def test_noop(device='cuda'):
@triton.jit
def kernel(x):
@@ -861,9 +885,9 @@ def test_noop(device='cuda'):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

View File

@@ -1,16 +1,17 @@
import torch
import triton
import triton.language as tl
import numpy as np
import pytest
import scipy.stats
import numpy as np
import torch
from numpy.random import Philox
import triton
import triton.language as tl
#####################################
## Reference Philox Implementation
# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
#####################################
## Unit Tests
# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'):
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
@@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):

View File

@@ -1,6 +1,7 @@
import torch
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx
if is_causal: at_mask = torch.tril(at_mask)
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :]

View File

@@ -1,14 +1,16 @@
import torch
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']\
for mode in ['forward', 'backward']
]
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
@@ -30,4 +32,4 @@ def test_op(M, N, dtype, mode):
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)
triton.testing.assert_almost_equal(th_dx, tt_dx)

View File

@@ -1,8 +1,10 @@
import pytest
import itertools
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
@@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)

View File

@@ -1,13 +1,16 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil
import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1
@@ -73,11 +85,12 @@ def test_reuse():
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1

View File

@@ -1,9 +1,11 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix():

View File

@@ -1,26 +1,26 @@
import ast
import builtins
import dbm
import functools
import inspect
import struct
import sys
import textwrap
import hashlib
import inspect
import os
import pickle
import struct
import subprocess
import os
import sys
import tempfile
import textwrap
import time
import warnings
from .tools.disasm import extract
from typing import Dict, Optional
import torch
from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
from filelock import FileLock
import dbm
import tempfile
from typing import Optional, Dict
import time
from .tools.disasm import extract
class CodeGenerator(ast.NodeVisitor):
@@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i-1]
arg_node = node.args.args[-i - 1]
annotation = arg_node.annotation
name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store())
@@ -134,8 +134,7 @@ class CodeGenerator(ast.NodeVisitor):
fn.args[idx].name = arg_name
arg_values.append(fn.args[idx])
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
# default: call visit_Assign
return self.visit_Assign(node)
def visit_Assign(self, node):
_names = []
for target in node.targets:
@@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
if else_bb:
self.builder.set_insert_block(else_bb)
is_terminator = self.visit_compound_statement(node.orelse)
#TODO: last statement is a terminator?
# TODO: last statement is a terminator?
if not is_terminator:
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
@@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
_builder=self.builder)
build_cond = lambda: triton.language.where(self.visit(pos_step_node),
self.visit(pos_cond_node),
self.visit(neg_cond_node),
_builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation
@@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
return fn(*args, **kws)
@@ -505,10 +503,10 @@ class Binary:
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
device)
self.bin = bin
self.asm = bin.asm
@@ -520,8 +518,8 @@ class LoadedBinary:
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1,
grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem)
def get_sass(self, fun=None):
@@ -632,10 +630,14 @@ class Kernel:
@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
if N % 16 == 0:
return 16
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
@@ -705,7 +707,7 @@ class Kernel:
if binary is None:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
num_warps=num_warps, num_stages=num_stages,
constants=constants,
)
if bin_cache_path:
@@ -766,13 +768,12 @@ class Launcher:
def __call__(self, *wargs, **kwargs):
return self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
@@ -788,6 +789,7 @@ class Autotuner:
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
@@ -802,7 +804,7 @@ class Autotuner:
perf_model, top_k, prune_num_stages_by = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
@@ -814,6 +816,7 @@ class Autotuner:
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
@@ -836,9 +839,9 @@ class Autotuner:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
#########################3
# 3
class DependenciesFinder(ast.NodeVisitor):
@@ -888,7 +891,7 @@ class DependenciesFinder(ast.NodeVisitor):
def visit_Name(self, node):
return self.globals.get(node.id, None)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
@@ -917,10 +920,10 @@ class DependenciesFinder(ast.NodeVisitor):
self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
class JITFunction:
cache_hook = None
class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function
@@ -946,7 +949,6 @@ class JITFunction:
# forward docs
self.__doc__ = fn.__doc__
@property
@functools.lru_cache()
def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
@@ -1049,19 +1052,19 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
@@ -1069,7 +1072,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
@@ -1099,7 +1102,7 @@ def heuristics(values):
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
.type values: dict[str, Callable[[list[Any]], Any]]
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
def cdiv(x, y):
return (x + y - 1) // y
def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
@@ -1163,13 +1167,14 @@ def next_power_of_2(n):
######
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()

View File

@@ -1,8 +1,8 @@
import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
from functools import wraps
import triton
from triton._C.libtriton.triton import frontend, ir
# convert block/dtype to ir values
def _to_ir(x, builder):
@@ -65,7 +65,7 @@ def builtin(fn):
def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
return wrapper
@@ -111,6 +111,7 @@ class pointer_dtype:
def __str__(self):
return f'pointer<{self.element_ty}>'
# scalar types
int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8)
@@ -331,27 +332,27 @@ class constexpr:
def __rsub__(self, other):
return other.value - self.value
def __mul__(self, other):
return self.value * other.value
def __rmul__(self, other):
return other.value * self.value
def __truediv__(self, other):
return self.value / other.value
def __rtruediv__(self, other):
return other.value / self.value
def __floordiv__(self, other):
return self.value // other.value
def __rfloordiv__(self, other):
return other.value // self.value
#
def __gt__(self, other):
return self.value > other.value
@@ -360,25 +361,25 @@ class constexpr:
def __ge__(self, other):
return self.value >= other.value
def __rge__(self, other):
return other.value >= self.value
def __lt__(self, other):
return self.value < other.value
def __rlt__(self, other):
return other.value < self.value
def __le__(self, other):
return self.value <= other.value
def __rle__(self, other):
return other.value <= self.value
def __eq__(self, other):
return self.value == other.value
def __ne__(self, other):
return self.value != other.value
@@ -489,15 +490,16 @@ def broadcast_to(input, shape, _builder=None):
"""
return frontend.broadcast_to(input, shape, _builder)
@builtin
def cat(input, other, _builder=None):
"""
Concatenate the given blocks
:param input: The first input block.
:type input:
:type input:
:param other: The second input block.
:type other:
:type other:
"""
return frontend.cat(input, other, _builder)
@@ -508,7 +510,7 @@ def reshape(input, shape, _builder=None):
Tries to reshape the given block to a new shape.
:param input: The input block.
:type input:
:type input:
:param shape: The desired shape.
:type shape: Tuple[int]
@@ -546,7 +548,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
"""
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
@@ -565,7 +567,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
@builtin
def store(pointer, value, mask=None, _builder=None):
"""
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
@@ -600,9 +602,10 @@ def _add_atomic_docstr(name):
"""
func.__doc__ = docstr.format(name=name)
return func
return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
def atomic_xchg(pointer, val, mask=None, _builder=None):
return frontend.atomic_xchg(pointer, val, mask, _builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None):
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder)
def _add_math_1arg_docstr(name):
def _decorator(func):
@@ -694,24 +699,28 @@ def _add_math_1arg_docstr(name):
"""
func.__doc__ = docstr.format(name=name)
return func
return _decorator
@builtin
@_add_math_1arg_docstr("exponential")
def exp(x, _builder=None):
return frontend.exp(x, _builder)
@builtin
@_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None):
return frontend.log(x, _builder)
@builtin
@_add_math_1arg_docstr("cosine")
def cos(x, _builder=None):
return frontend.cos(x, _builder)
@builtin
@_add_math_1arg_docstr("sine")
def sin(x, _builder=None):
@@ -739,9 +748,10 @@ def _add_reduction_docstr(name):
"""
func.__doc__ = docstr.format(name=name)
return func
return _decorator
@builtin
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
def sum(input, axis, _builder=None):
return frontend.sum(input, axis, _builder)
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None):
@@ -778,7 +789,7 @@ def debug_barrier(_builder=None):
@builtin
def multiple_of(input, value, _builder=None):
"""
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`.
"""
return frontend.multiple_of(input, value, _builder)
@@ -786,7 +797,7 @@ def multiple_of(input, value, _builder=None):
@builtin
def max_contiguous(input, value, _builder=None):
"""
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
"""
return frontend.max_contiguous(input, value, _builder)
@@ -794,7 +805,7 @@ def max_contiguous(input, value, _builder=None):
@builtin
def max_contiguous(input, value, _builder=None):
"""
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
Let the compiler knows that the `value` first values in :code:`input` are contiguous.
"""
return frontend.max_contiguous(input, value, _builder)
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
def abs(x):
return where(x >= 0, x, -x)
@triton.jit
def cdiv(x, div):
"""
@@ -871,13 +883,14 @@ def ravel(x):
"""
return triton.language.reshape(x, [x.type.numel])
@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
"""
transformes indices of a row-major size_i*size_j matrix into those
of one where indices are row major for each group of size_j rows.
For example, for size_i = size_j = 4 and size_g = 2, it will transform
[[0 , 1 , 2 , 3 ],
[[0 , 1 , 2 , 3 ],
[4 , 5 , 6 , 7 ],
[8 , 9 , 10, 11],
[12, 13, 14, 15]]
@@ -888,16 +901,16 @@ def swizzle2d(i, j, size_i, size_j, size_g):
[9, 11, 13, 15]]
"""
# "unrolled index in array"
ij = i*size_j + j
ij = i * size_j + j
# number of elements in `size_g` groups
# of `size_j` columns
size_gj = size_g * size_j
# index of the group in which (i,j) is
group_id = ij // size_gj
group_id = ij // size_gj
# row-index of the first element of this group
off_i = group_id * size_g
# last group may have fewer rows
size_g = minimum(size_i - off_i, size_g)
size_g = minimum(size_i - off_i, size_g)
# new row and column indices
new_i = off_i + (ij % size_g)
new_j = (ij % size_gj) // size_g

View File

@@ -1,7 +1,6 @@
import triton
from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53

View File

@@ -1,4 +1,4 @@
#from .conv import _conv, conv
from .matmul import _matmul, matmul
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse
from .matmul import _matmul, matmul

View File

@@ -1,2 +1,2 @@
from .matmul import matmul
from .softmax import softmax
from .softmax import softmax

View File

@@ -1,6 +1,7 @@
import torch
import triton
import triton.language as tl
import torch
# ********************************************************
# --------------------------------------------------------
@@ -11,16 +12,17 @@ import torch
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
@@ -30,25 +32,25 @@ def _sdd_kernel(
block_id = tl.program_id(1) + grid_offset
lut += block_id * 3
# offsets
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
# initialize pointers to A
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + (off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak)
a_ptrs = A + (off_z * stride_za
+ off_h * stride_ha
+ offs_am[:, None] * stride_ma
+ offs_ak[None, :] * stride_ak)
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + (off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk)
b_ptrs = B + (off_z * stride_zb
+ off_h * stride_hb
+ offs_bn[None, :] * stride_nb
+ offs_bk[:, None] * stride_bk)
## ---------------- ##
## Inner Loop ##
## ---------------- ##
@@ -69,13 +71,14 @@ def _sdd_kernel(
## ---------------- ##
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + (off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc)
pc = C + (off_z * stride_zc
+ block_id * stride_hc
+ offs_cm[:, None] * stride_mc
+ offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
@@ -103,7 +106,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
)
return c
@@ -119,50 +122,52 @@ def sdd_lut(layout, block, device):
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
#------------#
#- Prologue -#
#------------#
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2)
pidz = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
K = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N)
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
## ---------------- ##
## Inner Loop ##
## ---------------- ##
@@ -177,7 +182,7 @@ def _dsd_kernel(
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b)
pa += inc_a
pb += inc_b*stride_bk
pb += inc_b * stride_bk
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
@@ -185,15 +190,16 @@ def _dsd_kernel(
inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column*TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N)
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
@@ -231,7 +237,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
# exit()
return c
def dsd_lut(layout, block, step, trans, device):
def dsd_lut(layout, block, step, trans, device):
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = sizes.nonzero(as_tuple=True)
sizes = sizes.flatten()
@@ -313,11 +319,11 @@ def dsd_lut(layout, block, step, trans, device):
# -----------------------------
@triton.jit
def _dds_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut,
A, B, C,
stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
):
@@ -348,7 +354,7 @@ def _dds_kernel(
+ offs_ak[None, :] * stride_ka
# initialize pointers to B (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8)
block_id = tl.multiple_of(block_id, 8)
offs_bn = tl.arange(0, TILE_N)
offs_bk = tl.arange(0, TILE_K)
ptrs_b = B + pid_z * stride_zb \
@@ -429,7 +435,7 @@ class _matmul(torch.autograd.Function):
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
@@ -499,10 +505,10 @@ class matmul:
def __call__(self, a, b, out = None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
return c
return c

View File

@@ -1,7 +1,8 @@
import triton.language as tl
import triton
import torch
import triton
import triton.language as tl
def num_warps(n):
if n < 512:
@@ -33,10 +34,10 @@ def _forward(
check = rbn < size
rbmn = tl.where(check, rbn, size - 1)
# block id and column id
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
columnid = tl.load(LUT + offset + rbmn * 4 + 1)
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = tl.load(LUT + offset + rbmn * 4 + 3)
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = tl.load(LUT + offset + rbmn * 4 + 3)
# pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = tl.load(px, mask=check, other=-float('inf'))
@@ -64,7 +65,7 @@ def _forward(
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m
# apply causal mask
is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm
is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
# computation
x = tl.softmax(x)
@@ -127,9 +128,9 @@ class _softmax(torch.autograd.Function):
@staticmethod
def forward(
ctx, x, scale, rpe,
key_padding_mask, attn_mask,
kp_mask_mode, attn_mask_mode,
ctx, x, scale, rpe,
key_padding_mask, attn_mask,
kp_mask_mode, attn_mask_mode,
is_causal,
spdims, block, lut, maxlut
):
@@ -161,15 +162,15 @@ class _softmax(torch.autograd.Function):
# run kernel
M = x.shape[0]
grid = [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK = block,
APPLY_SCALE = apply_scale,
APPLY_RPE = apply_rpe,
APPLY_KP_MASK = apply_kp_mask,
APPLY_ATTN_MASK = apply_attn_mask,
KP_MASK_MUL = (kp_mask_mode == 'mul'),
ATTN_MASK_MUL = (attn_mask_mode == 'mul'))
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK=block,
APPLY_SCALE=apply_scale,
APPLY_RPE=apply_rpe,
APPLY_KP_MASK=apply_kp_mask,
APPLY_ATTN_MASK=apply_attn_mask,
KP_MASK_MUL=(kp_mask_mode == 'mul'),
ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
@@ -211,10 +212,10 @@ class softmax:
self.lut_cache = dict()
def __call__(
self, x, scale=1., rpe=None,
key_padding_mask=None, attn_mask=None,
self, x, scale=1., rpe=None,
key_padding_mask=None, attn_mask=None,
key_padding_mask_mode='add', attn_mask_mode='add',
is_causal = False
is_causal=False
):
if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -224,11 +225,11 @@ class softmax:
raise ValueError('Key padding mask must be %s' % x.dtype)
lut, maxlut = self.make_lut(x.device)
x = _softmax.apply(
x, scale, rpe,
key_padding_mask, attn_mask,
key_padding_mask_mode, attn_mask_mode,
x, scale, rpe,
key_padding_mask, attn_mask,
key_padding_mask_mode, attn_mask_mode,
is_causal,
self.spdims, self.block,
self.spdims, self.block,
lut, maxlut
)
return x
return x

View File

@@ -1,7 +1,9 @@
import os
import torch
import triton
import triton.language as tl
import torch
def next_power_of_2(n):
@@ -104,4 +106,4 @@ class _cross_entropy(torch.autograd.Function):
return neg_logprobs, None
cross_entropy = _cross_entropy.apply
cross_entropy = _cross_entropy.apply

View File

@@ -1,11 +1,14 @@
import torch
import triton.language as tl
import triton
import triton.language as tl
from .matmul_perf_model import *
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
@@ -14,14 +17,15 @@ def get_configs_io_bound():
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
})
@@ -30,26 +34,26 @@ def get_configs_io_bound():
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'prune_num_stages_by' : prune_num_stages,
'perf_model': estimate_matmul_time,
'top_k': 10
'prune_num_stages_by': prune_num_stages,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
# matrix multiplication
@@ -68,12 +72,12 @@ def _kernel(A, B, C, M, N, K,
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K*SPLIT_K):
for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
@@ -117,10 +121,10 @@ class _matmul(torch.autograd.Function):
c = torch.empty((M, N), device=device, dtype=a.dtype)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8)
return c
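For reference, a minimal usage sketch of the autotuned matmul wrapped by `_matmul` above, assuming it is exposed as `triton.ops.matmul` (shapes, dtype, and tolerances here are illustrative only):
```
import torch

import triton.ops

# Hypothetical smoke test for the autotuned fp16 matmul defined in this file.
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = triton.ops.matmul(a, b)
# compare against cuBLAS with loose fp16 tolerances
torch.testing.assert_allclose(c, torch.matmul(a, b), rtol=1e-2, atol=1e-1)
```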

View File

@@ -1,116 +1,121 @@
import heapq
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
# backend, device,
num_warps, num_stages,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
):
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
compute_ms = total_ops / tput
# time to compute
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
compute_ms = total_ops / tput
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas/num_sm)
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2)
load_a_l2 = M*K*2*0.8*(num_cta_n-1)
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
load_b_l2 = N*K*2*0.8*(num_cta_m-1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
# loading time in ms
load_ms = total_dram/dram_bw + total_l2/l2_bw
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2)
load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
# loading time in ms
load_ms = total_dram / dram_bw + total_l2 / l2_bw
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram /store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram/reduce_bw
# c.zero_()
zero_ms = M*N*2/(1024*1024)/store_bw
store_ms += zero_ms
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram / store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram / reduce_bw
# c.zero_()
zero_ms = M * N * 2 / (1024 * 1024) / store_bw
store_ms += zero_ms
total_time_ms = max(compute_ms, load_ms) + store_ms
if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Active CTAs: {active_cta_ratio*100}%')
return total_time_ms
total_time_ms = max(compute_ms, load_ms) + store_ms
if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Active CTAs: {active_cta_ratio*100}%')
return total_time_ms
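The cost model above only needs the launch configuration and problem shape; a hypothetical call (all values chosen purely for illustration) might look like:
```
# Hypothetical invocation of the analytical cost model defined above.
est_ms = estimate_matmul_time(
    num_warps=8, num_stages=3,
    M=4096, N=4096, K=4096,
    BLOCK_M=128, BLOCK_N=256, BLOCK_K=32, SPLIT_K=1,
    debug=True,  # also prints the compute / load / store breakdown
)
```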
def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {}
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map:
configs_map[key].append((config, num_stages))
else:
configs_map[key] = [(config, num_stages)]
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {}
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
mma_cycles = mmas/min(4, num_warps) * 8
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map:
configs_map[key].append((config, num_stages))
else:
configs_map[key] = [(config, num_stages)]
ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency/mma_cycles
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
mma_cycles = mmas / min(4, num_warps) * 8
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency / mma_cycles
for n in nearest:
pruned_configs.append(n[0])
else: # Volta & Turing only support num_stages <= 2
random_config = v[0][0]
random_config.num_stages = 2
pruned_configs.append(random_config)
return pruned_configs
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])
else: # Volta & Turing only support num_stages <= 2
random_config = v[0][0]
random_config.num_stages = 2
pruned_configs.append(random_config)
return pruned_configs
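To make the Ampere branch of this heuristic concrete, here is a hypothetical worked example (tile sizes and warp count assumed):
```
# Hypothetical worked example of the num_stages heuristic (cc >= 80 branch).
BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 64, 32, 32, 2  # assumed tile/config
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)    # 32 m16n8k16 MMAs
mma_cycles = mmas / min(4, num_warps) * 8             # 128 cycles
optimal_num_stages = 300 / mma_cycles                 # ~2.34
# Stage counts below the optimum are penalized by +10 in the sort key, so
# candidates just above ~2.3 (3, then 4) rank first, while a 2-stage variant
# ranks last.
```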

View File

@@ -1,10 +1,11 @@
import torch
import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
import subprocess
import sys
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
try:
import triton._C.libtriton.cutlass as _cutlass
@@ -13,6 +14,7 @@ except ImportError:
_cutlass = None
has_cutlass = False
def catch_oor(kernel, pytest_handle=None):
try:
res = kernel()
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
# run function
dtype = str(a.dtype).split('.')[-1]
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka,\
a.stride(0), a.stride(1),\
b.stride(0), b.stride(1),\
c.stride(0), c.stride(1),\
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, Ka,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dtype, dtype, dtype,
a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret
def assert_almost_equal(x, y, decimal=2, err_msg=''):
import numpy.testing as npt
if isinstance(x, torch.Tensor):
@@ -93,6 +96,7 @@ def nvsmi(attrs):
ret = [int(x) for x in ret]
return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -122,13 +126,13 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup/estimate_ms))
n_repeat = max(1, int(rep/estimate_ms))
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up
for _ in range(n_warmup):
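A minimal usage sketch of this benchmarking helper (the benchmarked function and shapes are arbitrary; with the default percentiles it returns the median, 20th- and 80th-percentile times in milliseconds):
```
import torch

import triton

# Hypothetical benchmark of a cuBLAS fp16 matmul using the helper above.
a = torch.randn((4096, 4096), device='cuda', dtype=torch.float16)
b = torch.randn((4096, 4096), device='cuda', dtype=torch.float16)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b),
                                              warmup=25, rep=100)
print(f'median: {ms:.3f} ms (p20 {min_ms:.3f}, p80 {max_ms:.3f})')
```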
@@ -161,6 +165,7 @@ class Benchmark:
"""
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
"""
def __init__(
self,
x_names,
@@ -224,9 +229,10 @@ class Mark:
self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data):
import os
import matplotlib.pyplot as plt
import pandas as pd
import os
y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names]
@@ -259,7 +265,7 @@ class Mark:
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
ax.set_xlabel(xlabel)
ax.set_ylabel(bench.ylabel)
#ax.set_title(bench.plot_name)
# ax.set_title(bench.plot_name)
ax.set_xscale("log" if bench.x_log else "linear")
ax.set_yscale("log" if bench.y_log else "linear")
if show_plots:
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper
def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA
@@ -306,17 +313,18 @@ def get_dram_gbps(backend=None, device=None):
device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps
def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
# assume fp32 += fp16*fp16
cc = _triton.runtime.cc(backend, device)
if cc < 80:
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:
ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
return tflops

View File

@@ -21,8 +21,8 @@
# SOFTWARE.
import argparse
import subprocess
import re
import subprocess
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')

View File

@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel
# --------------------------
from triton.language.core import constexpr
import torch
import triton
import triton.language as tl
@@ -38,7 +38,7 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case
# Load x and y from DRAM, masking out any extra elements in case
# the input is not a multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
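A host-side sketch of how a kernel like `add_kernel` is typically launched; the exact wrapper and argument order in the tutorial are assumed from the excerpt, and `BLOCK_SIZE=1024` is an arbitrary choice:
```
# Hypothetical host-side wrapper for the add_kernel excerpted above.
def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # 1D launch grid: one program instance per BLOCK_SIZE elements
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
```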

View File

@@ -16,6 +16,8 @@ You will learn about:
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import triton.language as tl
import triton
import torch
@@ -59,13 +61,10 @@ def naive_softmax(x):
# power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
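A host-side sketch illustrating the "pad each row to the next power of two" strategy described above; the wrapper itself is hypothetical, with argument order taken from the kernel signature in the excerpt:
```
# Hypothetical host-side wrapper for the softmax_kernel excerpted above.
def softmax(x):
    n_rows, n_cols = x.shape
    # smallest power of two >= n_cols, so one masked load covers a full row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # one program instance per row
    softmax_kernel[(n_rows,)](
        y, x, x.stride(0), y.stride(0), n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y
```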
@@ -136,7 +135,7 @@ y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch))
#%%
# %%
# As expected, the results are identical.
# %%
@@ -187,5 +186,5 @@ benchmark.run(show_plots=True, print_data=True)
# In the above plot, we can see that:
#
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# Note however that the PyTorch `softmax` operation is more general and works on tensors of any shape.

View File

@@ -112,13 +112,13 @@ You will specifically learn about:
# # number of programs ids along the N axis
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# # number of programs in group
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# # id of the group this program is in
# group_id = pid // num_pid_in_group
# group_id = pid // num_pid_in_group
# # row-id of the first program in the group
# first_pid_m = group_id * GROUP_SIZE_M
# first_pid_m = group_id * GROUP_SIZE_M
# # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# # *within groups*, programs are ordered in a column-major order
# # row-id of the program in the *launch grid*
# pid_m = first_pid_m + (pid % group_size_m)
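The grouped ordering sketched in the comments above can be reproduced in plain Python; a minimal sketch (the helper name and inputs are illustrative):
```
# Plain-Python illustration of the grouped (super-grouped) program ordering
# described in the comments above.
def grouped_pid(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    cdiv = lambda a, b: (a + b - 1) // b
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # if num_pid_m isn't divisible by GROUP_SIZE_M, the last group is smaller
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # within a group, programs are ordered column-major
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```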
@@ -141,6 +141,7 @@ You will specifically learn about:
#
import torch
import triton
import triton.language as tl
@@ -152,18 +153,19 @@ import triton.language as tl
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@@ -185,7 +187,7 @@ def matmul_kernel(
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
@@ -196,16 +198,16 @@ def matmul_kernel(
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# We will advance this pointer as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
@@ -213,8 +215,8 @@ def matmul_kernel(
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
@@ -223,8 +225,8 @@ def matmul_kernel(
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results.
a = tl.load(a_ptrs)
@@ -236,7 +238,7 @@ def matmul_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32 !
if meta['ACTIVATION']:
if meta['ACTIVATION']:
accumulator = meta['ACTIVATION'](accumulator)
c = accumulator.to(tl.float16)

View File

@@ -13,7 +13,7 @@ whose state is generally composed of a bit mask tensor of the same shape as the
# %%
# Baseline
# -------------
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# of deep neural networks in low-data regimes (i.e. regularization).
#
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
@@ -30,16 +30,18 @@ whose state is generally composed of a bit mask tensor of the same shape as the
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero
**meta,
):
BLOCK_SIZE = meta['BLOCK_SIZE']
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
@@ -88,7 +91,7 @@ print(tabulate.tabulate([
# of persisting randomness across multiple invocations of the kernel.
#
# Pseudorandom number generation in Triton is simple! In this tutorial we will use the
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
# other :ref:`random number generation strategies <Random Number Generation>`.
#
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
#
# Let's put it all together.
@triton.jit
def _seeded_dropout(
x_ptr,
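Based on the description above (a seed plus per-element offsets fed to `triton.language.rand`), the body of the seeded kernel might look roughly like the following sketch; the exact tutorial code may differ:
```
# Hypothetical sketch of a seeded dropout kernel following the description
# above: derive the keep-mask from tl.rand(seed, offsets) instead of loading it.
@triton.jit
def _seeded_dropout_sketch(x_ptr, output_ptr, n_elements, p, seed, **meta):
    BLOCK_SIZE = meta['BLOCK_SIZE']
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # uniformly distributed float32 values in [0, 1), one per offset
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # rescale kept elements by 1 / (1 - p), zero out the rest
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)
```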

View File

@@ -4,15 +4,17 @@ Layer Normalization
"""
import torch
import triton.language as tl
import triton
import triton.language as tl
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
# position of elements processed by this program
row = tl.program_id(0)
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
@@ -24,9 +26,9 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean*rstd
var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean * rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
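For comparison, the per-row statistics computed above correspond to the usual layer-norm definition; a plain-PyTorch reference (a sketch, not the kernel's exact epilogue) is:
```
# Plain-PyTorch reference for the fused forward kernel above (illustration).
def layer_norm_ref(x, weight, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    rstd = 1.0 / torch.sqrt(var + eps)
    xhat = (x - mean) * rstd
    return xhat * weight + bias
```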
@@ -41,16 +43,16 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
stride, N, eps,
**META):
stride, N, eps,
**META):
GROUP_SIZE_M = META['GROUP_SIZE_M']
BLOCK_SIZE_N = META['BLOCK_SIZE_N']
# position of elements processed by this program
row = tl.program_id(0)
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
@@ -59,28 +61,28 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
# these buffers stay in the L2, which allows this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id*N + cols
DB = DB + lock_id*N + cols
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean)*rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy , 0.)
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat*mean1 + mean2))*rstd
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy*xhat).to(w.dtype)
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
@@ -97,24 +99,27 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
pid = tl.program_id(0)
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None]*N + cols[None, :]
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols<N)
tl.store(FINAL_DB + cols, sum_db, mask=cols<N)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
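In plain PyTorch terms, `_layer_norm_bwd_dwdb` simply reduces the `(GROUP_SIZE_M, N)` partial buffers column-wise (a sketch; `_dw` and `_db` follow the names used in the backward pass below):
```
# Equivalent reduction in plain PyTorch (illustration only):
dw = _dw.sum(dim=0)  # (GROUP_SIZE_M, N) -> (N,)
db = _db.sum(dim=0)
```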
class LayerNorm(torch.autograd.Function):
@@ -129,19 +134,19 @@ class LayerNorm(torch.autograd.Function):
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
@@ -154,11 +159,11 @@ class LayerNorm(torch.autograd.Function):
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda')
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
@@ -166,14 +171,14 @@ class LayerNorm(torch.autograd.Function):
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M = 32,
BLOCK_SIZE_N = 128)
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
@@ -184,10 +189,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x)
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
@@ -218,14 +224,14 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'):
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x)
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
@@ -238,14 +244,15 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)