diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py index b6eacd884..d678f49f8 100644 --- a/python/bench/bench_blocksparse.py +++ b/python/bench/bench_blocksparse.py @@ -1,4 +1,5 @@ import torch + import triton # ------------------------------- @@ -8,18 +9,18 @@ import triton nt = {False: 'n', True: 't'} square_confs = [ triton.testing.Benchmark( - x_names = ['M', 'N', 'K'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - line_arg = 'block', - line_vals = [16, 32, 64, 128], - line_names = ['Block16', 'Block32', 'Block64', 'Block128'], - ylabel = 'TFLOPS', - plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', - args = {'layout_mode': layout_mode, 'op_mode': op_mode, - 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} - )\ - for AT in [False] for BT in [False] \ - for op_mode in ['dsd'] for layout_mode in ['dense'] + x_names=['M', 'N', 'K'], + x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144], + line_arg='block', + line_vals=[16, 32, 64, 128], + line_names=['Block16', 'Block32', 'Block64', 'Block128'], + ylabel='TFLOPS', + plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', + args={'layout_mode': layout_mode, 'op_mode': op_mode, + 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} + ) + for AT in [False] for BT in [False] + for op_mode in ['dsd'] for layout_mode in ['dense'] ] @@ -27,7 +28,7 @@ square_confs = [ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): Z, H = 1, 1 make_layout = { - 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ + 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), }[layout_mode] # create layout @@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) num_flops = { - 'sdd': 2 * Z * K * float(layout.sum()) * block * block,\ - 'dsd': 2 * Z * N * float(layout.sum()) * block * block,\ + 'sdd': 2 * Z * K * float(layout.sum()) * block * block, + 'dsd': 2 * Z * N * float(layout.sum()) * block * block, 'dds': 2 * Z * M * float(layout.sum()) * block * block - }[op_mode]*1e-12 + }[op_mode] * 1e-12 return tflops(mean_ms), tflops(min_ms), tflops(max_ms) @@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, square_confs = [ triton.testing.Benchmark( - x_names = ['M', 'N'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - line_arg = 'block', - line_vals = [16, 32, 64], - line_names = ['Block16', 'Block32', 'Block64'], - ylabel = 'GBPS', - plot_name = f'{layout_mode}-square', - args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} - )\ + x_names=['M', 'N'], + x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144], + line_arg='block', + line_vals=[16, 32, 64], + line_names=['Block16', 'Block32', 'Block64'], + ylabel='GBPS', + plot_name=f'{layout_mode}-square', + args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} + ) for layout_mode in ['dense', 'tril'] ] @@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50): return gbps(mean_ms), gbps(min_ms), gbps(max_ms) -bench_matmul.run(print_data=True, show_plots=True) \ No newline at end of file 
+bench_matmul.run(print_data=True, show_plots=True) diff --git a/python/bench/bench_cross_entropy.py b/python/bench/bench_cross_entropy.py index 5347ae24a..aaa0e28f5 100644 --- a/python/bench/bench_cross_entropy.py +++ b/python/bench/bench_cross_entropy.py @@ -1,17 +1,18 @@ import torch + import triton confs = [ triton.testing.Benchmark( - x_names = ['N'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], - line_arg = 'provider', - line_vals = ['triton', 'torch'], - line_names = ['Triton', 'Torch'], - ylabel = 'GBPS', - plot_name = f'{mode}-2048', - args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} - )\ + x_names=['N'], + x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], + line_arg='provider', + line_vals=['triton', 'torch'], + line_names=['Triton', 'Torch'], + ylabel='GBPS', + plot_name=f'{mode}-2048', + args={'M': 2048, 'dtype': torch.float16, 'mode': mode} + ) for mode in ['forward', 'backward'] ] @@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider): num_gb = (2 * x.numel() * x.element_size() * 1e-9) gbps = lambda ms: num_gb / ms * 1e3 # forward pass - op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \ - 'triton': triton.ops.cross_entropy}[provider] + op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), + 'triton': triton.ops.cross_entropy}[provider] if mode == 'forward': mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx)) if mode == 'backward': @@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider): if __name__ == '__main__': - bench_op.run(print_data=True) \ No newline at end of file + bench_op.run(print_data=True) diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 7e912be31..9db005da0 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -1,6 +1,6 @@ -import triton import torch -import os + +import triton def rounded_linspace(low, high, steps, div): @@ -29,16 +29,16 @@ square_confs = [ transformer_confs = [ triton.testing.Benchmark( x_names=[x], - x_vals = rounded_linspace(NK//16, NK, 32, 128), + x_vals=rounded_linspace(NK // 16, NK, 32, 128), line_arg="provider", line_vals=["cublas", "triton", "cutlass"], line_names=["cuBLAS", "Triton", "CUTLASS"], ylabel="TFLOPS", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", - args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} - ) for NK in [12288]\ - for i, x in enumerate(["N", "K"])\ - for M in [2048] + args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16} + ) for NK in [12288] + for i, x in enumerate(["N", "K"]) + for M in [2048] ] @@ -46,8 +46,10 @@ transformer_confs = [ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) - if AT: a = a.t() - if BT: b = b.t() + if AT: + a = a.t() + if BT: + b = b.t() num_flops = 2 * M * N * K tflops = lambda ms: 2. 
* M * N * K / ms * 1e-9 if provider == "cublas": @@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): try: ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep) return tflops(ms), tflops(max_ms), tflops(min_ms) - except: + except Exception: return None return None diff --git a/python/bench/run.py b/python/bench/run.py index c23884bb5..5e6e3b392 100644 --- a/python/bench/run.py +++ b/python/bench/run.py @@ -1,7 +1,8 @@ import argparse -import sys -import os import inspect +import os +import sys + import triton diff --git a/python/setup.py b/python/setup.py index 17db76093..28194f41e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,29 +1,28 @@ -import os -import re -import sys -import sysconfig -import platform -import subprocess import distutils -import glob -import tempfile -import shutil -from distutils.version import LooseVersion -from setuptools import setup, Extension, find_packages -from setuptools.command.build_ext import build_ext -from setuptools.command.test import test as TestCommand import distutils.spawn -import urllib.request +import os +import platform +import re +import shutil +import subprocess +import sys import tarfile +import tempfile +import urllib.request +from distutils.version import LooseVersion + +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext + def get_llvm(): # tries to find system LLVM - versions = ['-11.0', '-11', '-11-64'] + versions = ['-11.0', '-11', '-11-64'] supported = ['llvm-config{v}'.format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [p for p in paths if p is not None] if paths: - return '', '' + return '', '' # download if nothing is installed name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04' dir = '/tmp' @@ -32,7 +31,7 @@ def get_llvm(): if not os.path.exists(llvm_library_dir): try: shutil.rmtree(os.path.join(dir, name)) - except: + except Exception: pass url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name) print('downloading and extracting ' + url + '...') @@ -96,7 +95,7 @@ class CMakeBuild(build_ext): "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir, "-DLLVM_LIBRARY_DIR=" + llvm_library_dir, #'-DPYTHON_EXECUTABLE=' + sys.executable, - #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', + # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) ] diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index ce93786b8..012ff65d7 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -1,14 +1,18 @@ -from numpy import record -import torch -import triton +import triton.language as tl import subprocess import sys + import pytest +import torch +from numpy import record + +import triton ####################### # Utilities ####################### + def nvsmi(attrs): attrs = ','.join(attrs) cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] @@ -23,48 +27,51 @@ def nvsmi(attrs): ####################### matmul_data = { - # square - (256 , 256 , 256 ) : {'v100': 0.027}, - (512 , 512 , 512 ) : {'v100': 0.158}, - (1024, 1024, 1024 ) : {'v100': 0.466}, - (2048, 2048, 2048 ) : {'v100': 0.680}, - (4096, 4096, 4096 ) : {'v100': 0.831}, - (8192, 8192, 8192 ) : {'v100': 0.849}, - # tall-skinny - (16 , 1024, 1024 ) : {'v100': 
0.0128}, - (16 , 4096, 4096 ) : {'v100': 0.0883}, - (16 , 8192, 8192 ) : {'v100': 0.101}, - (64 , 1024, 1024 ) : {'v100': 0.073}, - (64 , 4096, 4096 ) : {'v100': 0.270}, - (64 , 8192, 8192 ) : {'v100': 0.360}, - (1024, 64 , 1024 ) : {'v100': 0.0692}, - (4096, 64 , 4096 ) : {'v100': 0.264}, - (8192, 64 , 8192 ) : {'v100': 0.323}, -# # deep reductions -# (64 , 64 , 16384) : {'v100': 0.}, -# (64 , 64 , 65536) : {'v100': 0.}, -# (256 , 256 , 8192 ) : {'v100': 0.}, -# (256 , 256 , 32768) : {'v100': 0.}, + # square + (256, 256, 256): {'v100': 0.027}, + (512, 512, 512): {'v100': 0.158}, + (1024, 1024, 1024): {'v100': 0.466}, + (2048, 2048, 2048): {'v100': 0.680}, + (4096, 4096, 4096): {'v100': 0.831}, + (8192, 8192, 8192): {'v100': 0.849}, + # tall-skinny + (16, 1024, 1024): {'v100': 0.0128}, + (16, 4096, 4096): {'v100': 0.0883}, + (16, 8192, 8192): {'v100': 0.101}, + (64, 1024, 1024): {'v100': 0.073}, + (64, 4096, 4096): {'v100': 0.270}, + (64, 8192, 8192): {'v100': 0.360}, + (1024, 64, 1024): {'v100': 0.0692}, + (4096, 64, 4096): {'v100': 0.264}, + (8192, 64, 8192): {'v100': 0.323}, + # # deep reductions + # (64 , 64 , 16384) : {'v100': 0.}, + # (64 , 64 , 65536) : {'v100': 0.}, + # (256 , 256 , 8192 ) : {'v100': 0.}, + # (256 , 256 , 32768) : {'v100': 0.}, } + + @pytest.mark.parametrize('M, N, K', matmul_data.keys()) def test_matmul(M, N, K): ref_gpu_util = matmul_data[(M, N, K)]['v100'] cur_sm_clock = nvsmi(['clocks.current.sm'])[0] ref_sm_clock = 1350 - max_gpu_perf = 1e-6*80*8*128*cur_sm_clock + max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz' a = torch.randn((M, K), dtype=torch.float16, device='cuda') b = torch.randn((K, N), dtype=torch.float16, device='cuda') fn = lambda: triton.ops.matmul(a, b) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) - cur_gpu_perf = 2.*M*N*K/ms * 1e-9 + cur_gpu_perf = 2. * M * N * K / ms * 1e-9 cur_gpu_util = cur_gpu_perf / max_gpu_perf triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) + ####################### # Element-Wise ####################### -import triton.language as tl + @triton.jit def _add(x_ptr, y_ptr, output_ptr, n_elements, @@ -80,21 +87,22 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements, elementwise_data = { - 1024*16 : {'v100': 0.0219}, - 1024*64 : {'v100': 0.0791}, - 1024*256 : {'v100': 0.243}, - 1024*1024 : {'v100': 0.534}, - 1024*4096 : {'v100': 0.796}, - 1024*16384: {'v100': 0.905}, - 1024*65536: {'v100': 0.939}, + 1024 * 16: {'v100': 0.0219}, + 1024 * 64: {'v100': 0.0791}, + 1024 * 256: {'v100': 0.243}, + 1024 * 1024: {'v100': 0.534}, + 1024 * 4096: {'v100': 0.796}, + 1024 * 16384: {'v100': 0.905}, + 1024 * 65536: {'v100': 0.939}, } + @pytest.mark.parametrize('N', elementwise_data.keys()) def test_elementwise(N): ref_gpu_util = elementwise_data[N]['v100'] cur_mem_clock = nvsmi(['clocks.current.memory'])[0] ref_mem_clock = 877 - max_gpu_perf = 512*2*ref_mem_clock*1e-3 + max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3 assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) @@ -102,7 +110,6 @@ def test_elementwise(N): grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250) - cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6 + cur_gpu_perf = 3. 
* N * z.element_size() / ms * 1e-6 cur_gpu_util = cur_gpu_perf / max_gpu_perf triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) - diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 41c9e9236..7f0af78b4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -86,6 +86,7 @@ def patch_kernel(template, to_replace): @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 + @triton.jit def kernel(X, SIZE: tl.constexpr): pass @@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'): def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid + @triton.jit def kernel(Z, X, SIZE: tl.constexpr): off = tl.arange(0, SIZE) @@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): SIZE = 128 # define the kernel / launch-grid + @triton.jit def kernel(Z, X, Y, SIZE: tl.constexpr): off = tl.arange(0, SIZE) @@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: # --------------- # test binary ops # --------------- + + @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) - for op in ['+', '-', '*', '/', '%'] - for dtype_x in dtypes - for dtype_y in dtypes + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes + for dtype_y in dtypes ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' @@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): @pytest.mark.parametrize("dtype_x, dtype_y", - [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + - [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] -) + [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] + ) def test_floordiv(dtype_x, dtype_y, device='cuda'): # Triton has IEEE, not numpy/torch, semantics for %, and those carry # through to //, so we have to use a nonstandard expression to get a @@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'): # test compare ops # --------------- ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \ -# real -[ - (dtype_x, dtype_y, op, 'real', 'real') \ - for op in ops \ - for dtype_x in dtypes \ - for dtype_y in dtypes -] + \ -# NaNs -[('float32', 'float32', op, mode_x, mode_y) \ - for op in ops - for mode_x, mode_y in [('nan' , 'real'), - ('real', 'nan'), - ('nan' , 'nan')] -]) + +@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", + # real + [ + (dtype_x, dtype_y, op, 'real', 'real') + for op in ops + for dtype_x in dtypes + for dtype_y in dtypes + ] + + # NaNs + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), + ('real', 'nan'), + ('nan', 'nan')] + + ]) def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): expr = f'x {op} y' if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): @@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'): # 'exp', 'log', 'cos', 'sin' # ]) + @pytest.mark.parametrize("expr", [ 'exp', 'log', 'cos', 'sin' ]) @@ -368,8 +376,8 @@ def 
make_ptr_str(name, shape): @pytest.mark.parametrize("expr, dtype_str", [ (f'x[{s}]', d) - for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] - for d in ['int32', 'uint32', 'uint16'] + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16'] ]) def test_index1d(expr, dtype_str, device='cuda'): rank_x = expr.count(':') @@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'): @triton.jit def fn(a, b): return a + b, \ - a - b, \ - a * b + a - b, \ + a * b def test_tuples(): @@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): # --------------- @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ (dtype_x, dtype_z, False) - for dtype_x in dtypes - for dtype_z in dtypes + for dtype_x in dtypes + for dtype_z in dtypes ] + [ ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), @@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): @triton.jit def kernel(X, Z, BITCAST: tl.constexpr): x = tl.load(X) - z = x.to(Z.dtype.element_ty, bitcast = BITCAST) + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) tl.store(Z, z) # triton result @@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # --------------- # test reduce # --------------- + + @pytest.mark.parametrize("dtype_str, shape", - [(dtype, shape) \ - for dtype in dtypes\ - for shape in [128, 512]]) + [(dtype, shape) + for dtype in dtypes + for shape in [128, 512]]) def test_reduce1d(dtype_str, shape, device='cuda'): # triton kernel @@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'): def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) - x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :]) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) z = tl.sum(x, axis=AXIS) tl.store(Z + range_m, z) # input @@ -608,11 +618,13 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'): # --------------- # test permute # --------------- + + @pytest.mark.parametrize("dtype_str, shape, perm", - [(dtype, shape, perm) \ - for dtype in ['float32']\ - for shape in [(128, 128)]\ - for perm in [(1, 0)]]) + [(dtype, shape, perm) + for dtype in ['float32'] + for shape in [(128, 128)] + for perm in [(1, 0)]]) def test_permute(dtype_str, shape, perm, device='cuda'): # triton kernel @@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # test dot # --------------- + @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) def test_dot(epilogue, device='cuda'): # triton kernel @@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'): y_tri, y_tri.stride(0), y_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, - ADD_MATRIX = epilogue=='add-matrix', - ADD_ROWS = epilogue=='add-rows', - ADD_COLS = epilogue=='add-cols') + ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', + ADD_COLS=epilogue == 'add-cols') # torch result z_ref = np.matmul(x, y) if epilogue == 'add-matrix': z_ref += z if epilogue == 'add-rows': - z_ref += z[:,0][:, None] + z_ref += z[:, 0][:, None] if epilogue == 'add-cols': - z_ref += z[0,:][None, :] + z_ref += z[0, :][None, :] # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # make sure ld/st are vectorized @@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'): assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx + def 
test_dot_without_load(): @triton.jit def kernel(out): @@ -713,28 +727,30 @@ def test_dot_without_load(): b = tl.zeros((32, 32), tl.float32) c = tl.zeros((32, 32), tl.float32) c = tl.dot(a, b) - pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :] + pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] tl.store(pout, c) - out = torch.ones((32,32), dtype=torch.float32, device="cuda") + out = torch.ones((32, 32), dtype=torch.float32, device="cuda") kernel[(1,)](out) # --------------- # test arange # --------------- + @pytest.mark.parametrize("start", [0, 1, 7, 16]) def test_arange(start, device='cuda'): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + @triton.jit def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): off = tl.arange(0, BLOCK) val = tl.arange(START, END) tl.store(z + off, val) - _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK) - z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device) + _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) triton.testing.assert_almost_equal(z_tri, z_ref) # --------------- @@ -742,6 +758,8 @@ def test_arange(start, device='cuda'): # --------------- # 'bfloat16': torch.bfloat16, # Testing masked loads with an intermate copy to shared memory run. + + @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device='cuda'): M = 32 @@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'): N_offsets = tl.arange(0, N) K_offsets = tl.arange(0, K) - in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] - in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] # Load inputs. 
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel) @@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'): o = tl.dot(x, w) # Store output - output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel) pgm = _kernel[(1,)](in1, in2, out, - in1.stride()[0], - in2.stride()[0], - out.stride()[0], - in1.numel(), - in2.numel(), - out.numel(), - M=M, N=N, K=K) + in1.stride()[0], + in2.stride()[0], + out.stride()[0], + in1.numel(), + in2.numel(), + out.numel(), + M=M, N=N, K=K) - reference_out =torch.matmul(in1, in2) + reference_out = torch.matmul(in1, in2) triton.testing.allclose(out, reference_out) + @pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) def test_load_cache_modifier(cache): src = torch.empty(128, device='cuda') @@ -796,8 +815,8 @@ def test_load_cache_modifier(cache): @triton.jit def _kernel(dst, src, CACHE: tl.constexpr): offsets = tl.arange(0, 128) - x = tl.load(src+offsets, cache_modifier=CACHE) - tl.store(dst+offsets, x) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] @@ -830,11 +849,14 @@ def test_load_cache_modifier(cache): # --------------- # test default # --------------- -#TODO: can't be local to test_default +# TODO: can't be local to test_default + + @triton.jit -def _impl(value = 10): +def _impl(value=10): return value + def test_default(): value = 5 ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') @@ -851,7 +873,9 @@ def test_default(): # --------------- # test noop -#---------------- +# ---------------- + + def test_noop(device='cuda'): @triton.jit def kernel(x): @@ -861,9 +885,9 @@ def test_noop(device='cuda'): @pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'), - (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'), - (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64') + (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') ]) def test_value_specialization(value: int, value_type: str, device='cuda') -> None: diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 67173adfb..82ae7f0c2 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -1,16 +1,17 @@ -import torch -import triton -import triton.language as tl +import numpy as np import pytest import scipy.stats -import numpy as np - +import torch from numpy.random import Philox +import triton +import triton.language as tl + ##################################### -## Reference Philox Implementation +# Reference Philox Implementation ##################################### + class PhiloxConfig: def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) @@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x): ##################################### -## Unit Tests +# Unit Tests ##################################### BLOCK = 1024 # test generation of random uint32 + + @pytest.mark.parametrize('size, seed', - [(size, seed) for size in ['10', '4,53', '10000']\ - for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] -) + [(size, 
seed) for size in ['10', '4,53', '10000'] + for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] + ) def test_randint(size, seed, device='cuda'): size = list(map(int, size.split(','))) + @triton.jit def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) @@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'): assert out_tri == out_ref # test uniform PRNG + + @pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000]\ - for seed in [0, 42, 124, 54]] -) + [(size, seed) for size in [1000000] + for seed in [0, 42, 124, 54]] + ) def test_rand(size, seed, device='cuda'): @triton.jit def kernel(X, N, seed): @@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'): assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 # test normal PRNG + + @pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000]\ - for seed in [0, 42, 124, 54]] -) + [(size, seed) for size in [1000000] + for seed in [0, 42, 124, 54]] + ) def test_randn(size, seed, device='cuda'): @triton.jit def kernel(X, N, seed): diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index b9cdc23c7..ed569c04d 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -1,6 +1,7 @@ -import torch -import triton import pytest +import torch + +import triton @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) @@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE): # torch result rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) # broadcast at_mask to the same shape as rx - if is_causal: at_mask = torch.tril(at_mask) + if is_causal: + at_mask = torch.tril(at_mask) M = at_mask[None, None, :, :] + torch.zeros_like(rx) rx[M == 0] = float("-inf") # rx += kp_mask[:, None, None, :] diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index 48cb303bb..08516257b 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -1,14 +1,16 @@ -import torch -import triton import pytest +import torch + +import triton + @pytest.mark.parametrize("M, N, dtype, mode", - [ - (M, N, dtype, mode) for M in [1024, 821] - for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32']\ - for mode in ['forward', 'backward'] - ] + [ + (M, N, dtype, mode) for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] + ] ) def test_op(M, N, dtype, mode): dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype] @@ -30,4 +32,4 @@ def test_op(M, N, dtype, mode): x.grad.zero_() th_y.backward(dy) th_dx = x.grad.clone() - triton.testing.assert_almost_equal(th_dx, tt_dx) \ No newline at end of file + triton.testing.assert_almost_equal(th_dx, tt_dx) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 75241c291..1d413a0e6 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -1,8 +1,10 @@ -import pytest import itertools -import triton + +import pytest import torch +import triton + @pytest.mark.parametrize( "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE", @@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, K = BLOCK_K * SPLIT_K if K is None 
else K # allocate/transpose inputs DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] - a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) - b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) + a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) + b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a b = b.t() if BT else b # run test th_c = torch.matmul(a, b) - tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest) + tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest) triton.testing.assert_almost_equal(th_c, tt_c) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 3ad387f09..51c69b5b6 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,13 +1,16 @@ -import torch -import triton -from triton.code_gen import JITFunction -import triton.language as tl import os import shutil + import pytest +import torch + +import triton +import triton.language as tl +from triton.code_gen import JITFunction tmpdir = ".tmp" + @triton.jit def function_1(i): i = i + 1 @@ -20,18 +23,21 @@ def function_2(i): i = i + 1 return i + @triton.jit def kernel(X, i, BLOCK: tl.constexpr): i = i + 1 i = function_1(i) tl.store(X, i) + @triton.jit(do_not_specialize=["i"]) def kernel_nospec(X, i, BLOCK: tl.constexpr): i = i + 1 i = function_1(i) tl.store(X, i) + def apply_src_change(target, old, new): delattr(kernel.fn, 'hash') delattr(function_1.fn, 'hash') @@ -42,28 +48,34 @@ def apply_src_change(target, old, new): target.src = target.src.replace(new, old) return ret + def test_nochange(): baseline = kernel.cache_key updated = apply_src_change(kernel, 'i + 1', 'i + 1') assert baseline == updated + def test_toplevel_change(): baseline = kernel.cache_key updated = apply_src_change(kernel, 'i + 1', 'i + 2') assert baseline != updated + def test_nested1_change(): baseline = kernel.cache_key updated = apply_src_change(function_1, 'i + 1', 'i + 2') assert baseline != updated + def reset_tmp_dir(): os.environ["TRITON_CACHE_DIR"] = tmpdir if os.path.exists(tmpdir): shutil.rmtree(tmpdir) + def test_reuse(): counter = 0 + def inc_counter(key, binary, repr): nonlocal counter counter += 1 @@ -73,11 +85,12 @@ def test_reuse(): for i in range(10): kernel[(1,)](x, 1, BLOCK=1024) assert counter == 1 - + @pytest.mark.parametrize('mode', ['enable', 'disable']) def test_specialize(mode): counter = 0 + def inc_counter(key, binary, repr): nonlocal counter counter += 1 diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py index ae843a15f..6d0658f3b 100644 --- a/python/test/unit/runtime/test_comm.py +++ b/python/test/unit/runtime/test_comm.py @@ -1,9 +1,11 @@ -import torch -import triton -import pytest import subprocess -import triton.language as tl + import numpy as np +import pytest +import torch + +import triton +import triton.language as tl def get_p2p_matrix(): diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index eec36f052..439c1798e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,26 +1,26 @@ import ast import builtins +import dbm import functools -import inspect -import struct -import sys -import textwrap import hashlib +import inspect import os import pickle +import struct import subprocess -import os +import sys +import tempfile +import textwrap +import time import warnings -from .tools.disasm 
import extract +from typing import Dict, Optional + import torch +from filelock import FileLock + import triton import triton._C.libtriton.triton as _triton -from filelock import FileLock -import dbm -import tempfile -from typing import Optional, Dict -import time - +from .tools.disasm import extract class CodeGenerator(ast.NodeVisitor): @@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor): arg_names, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): - arg_node = node.args.args[-i-1] + arg_node = node.args.args[-i - 1] annotation = arg_node.annotation name = arg_node.arg st_target = ast.Name(id=name, ctx=ast.Store()) @@ -134,8 +134,7 @@ class CodeGenerator(ast.NodeVisitor): fn.args[idx].name = arg_name arg_values.append(fn.args[idx]) idx += 1 - - + for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) if inline: @@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor): # default: call visit_Assign return self.visit_Assign(node) - def visit_Assign(self, node): _names = [] for target in node.targets: @@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor): if else_bb: self.builder.set_insert_block(else_bb) is_terminator = self.visit_compound_statement(node.orelse) - #TODO: last statement is a terminator? + # TODO: last statement is a terminator? if not is_terminator: self.builder.br(endif_bb) self.module.seal_block(endif_bb) @@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor): pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) - build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ - self.visit(pos_cond_node),\ - self.visit(neg_cond_node),\ - _builder=self.builder) + build_cond = lambda: triton.language.where(self.visit(pos_step_node), + self.visit(pos_cond_node), + self.visit(neg_cond_node), + _builder=self.builder) #cond_node = neg_cond_node step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation @@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor): if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language.core: + sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) return fn(*args, **kws) @@ -505,10 +503,10 @@ class Binary: class LoadedBinary: def __init__(self, device: int, bin: Binary): - module, kernel = _triton.code_gen.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, + module, kernel = _triton.code_gen.load_binary(bin.backend, + bin.name, + bin.asm, + bin.shared_mem, device) self.bin = bin self.asm = bin.asm @@ -520,8 +518,8 @@ class LoadedBinary: def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, - grid_0, grid_1, grid_2, - self.bin.num_warps * 32, 1, 1, + grid_0, grid_1, grid_2, + self.bin.num_warps * 32, 1, 1, args, self.bin.shared_mem) def get_sass(self, fun=None): @@ -632,10 +630,14 @@ class Kernel: @staticmethod def pow2_divisor(N): - if N % 16 == 0: return 16 - if N % 8 == 0: return 8 - if N % 4 == 0: return 4 - if N % 2 == 0: return 2 + if N % 16 == 0: + return 16 + if N % 8 == 0: + return 8 + if N % 4 == 0: + return 4 + if N % 2 == 0: + return 2 return 1 def __init__(self, fn): @@ -675,7 +677,7 @@ class Kernel: 
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int) and i not in self.fn.do_not_specialize} # transforms ints whose value is one into constants for just-in-time compilation @@ -705,7 +707,7 @@ class Kernel: if binary is None: binary = self._compile( *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, + num_warps=num_warps, num_stages=num_stages, constants=constants, ) if bin_cache_path: @@ -766,13 +768,12 @@ class Launcher: def __call__(self, *wargs, **kwargs): return self.kernel(*wargs, **kwargs, grid=self.grid) - class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None): + def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: + :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. @@ -788,6 +789,7 @@ class Autotuner: self.hook = lambda args: 0 if reset_to_zero is not None: self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + def _hook(args): for i in self.reset_idx: args[i].zero_() @@ -802,7 +804,7 @@ class Autotuner: perf_model, top_k, prune_num_stages_by = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k self.prune_num_stages_by = prune_num_stages_by - + def _bench(self, *args, config, **meta): # check for conflicts, i.e. 
meta-parameters both provided # as kwargs and by the autotuner @@ -814,6 +816,7 @@ class Autotuner: ) # augment meta-parameters with tunable ones current = dict(meta, **config.kwargs) + def kernel_call(): if config.pre_hook: config.pre_hook(self.nargs) @@ -836,9 +839,9 @@ class Autotuner: top_k = int(len(self.configs) * top_k) if len(pruned_configs) > top_k: est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k] + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) \ + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} bench_end = time.time() self.bench_time = bench_end - bench_start @@ -876,7 +879,7 @@ def version_key(): ptxas_version = '' return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) -#########################3 +# 3 class DependenciesFinder(ast.NodeVisitor): @@ -888,7 +891,7 @@ class DependenciesFinder(ast.NodeVisitor): def visit_Name(self, node): return self.globals.get(node.id, None) - + def visit_Attribute(self, node): lhs = self.visit(node.value) while isinstance(lhs, ast.Attribute): @@ -917,10 +920,10 @@ class DependenciesFinder(ast.NodeVisitor): self.ret = (self.ret + func.hash).encode("utf-8") self.ret = hashlib.md5(self.ret).hexdigest() -class JITFunction: - - cache_hook = None +class JITFunction: + + cache_hook = None def __init__(self, fn, version=None, do_not_specialize=None): # information of wrapped function @@ -946,7 +949,6 @@ class JITFunction: # forward docs self.__doc__ = fn.__doc__ - @property @functools.lru_cache() def cache_key(self): @@ -1027,6 +1029,7 @@ class Config: :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this function are args. """ + def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps @@ -1049,19 +1052,19 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): .. highlight:: python .. code-block:: python - @triton.autotune(configs=[ + @triton.autotune(configs=[ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], + ], key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes + # the value of x_size changes ) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] - + :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. + This means that whatever value the kernel updates will be updated multiple times. To avoid this undesired behavior, you can use the `reset_to_zero` argument, which reset the value of the provided tensor to `zero` before running any configuration. @@ -1069,7 +1072,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): :type configs: list[triton.Config] :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 
:type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: + :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. @@ -1099,7 +1102,7 @@ def heuristics(values): def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - + .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. each such function takes a list of positional arguments as input. .type values: dict[str, Callable[[list[Any]], Any]] @@ -1150,6 +1153,7 @@ def jit(*args, **kwargs): def cdiv(x, y): return (x + y - 1) // y + def next_power_of_2(n): """Return the smallest power of 2 greater than or equal to n""" n -= 1 @@ -1163,13 +1167,14 @@ def next_power_of_2(n): ###### + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype - self.base = base + self.base = base self.is_cuda = base.is_cuda self.device = base.device - + def data_ptr(self): return self.base.data_ptr() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 210a72a30..6895c101c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,8 +1,8 @@ -import triton -from triton._C.libtriton.triton import ir -from triton._C.libtriton.triton import frontend from functools import wraps +import triton +from triton._C.libtriton.triton import frontend, ir + # convert block/dtype to ir values def _to_ir(x, builder): @@ -65,7 +65,7 @@ def builtin(fn): def wrapper(*args, **kwargs): if '_builder' not in kwargs or \ kwargs['_builder'] is None: - raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") + raise ValueError("Did you forget to add @triton.jit ? 
(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) return wrapper @@ -111,6 +111,7 @@ class pointer_dtype: def __str__(self): return f'pointer<{self.element_ty}>' + # scalar types int1 = dtype(ir.type.get_int1) int8 = dtype(ir.type.get_int8) @@ -331,27 +332,27 @@ class constexpr: def __rsub__(self, other): return other.value - self.value - + def __mul__(self, other): return self.value * other.value def __rmul__(self, other): return other.value * self.value - + def __truediv__(self, other): return self.value / other.value - + def __rtruediv__(self, other): return other.value / self.value - + def __floordiv__(self, other): return self.value // other.value - + def __rfloordiv__(self, other): return other.value // self.value # - + def __gt__(self, other): return self.value > other.value @@ -360,25 +361,25 @@ class constexpr: def __ge__(self, other): return self.value >= other.value - + def __rge__(self, other): return other.value >= self.value - + def __lt__(self, other): return self.value < other.value def __rlt__(self, other): return other.value < self.value - + def __le__(self, other): return self.value <= other.value - + def __rle__(self, other): return other.value <= self.value - + def __eq__(self, other): return self.value == other.value - + def __ne__(self, other): return self.value != other.value @@ -489,15 +490,16 @@ def broadcast_to(input, shape, _builder=None): """ return frontend.broadcast_to(input, shape, _builder) + @builtin def cat(input, other, _builder=None): """ Concatenate the given blocks :param input: The first input block. - :type input: + :type input: :param other: The second input block. - :type other: + :type other: """ return frontend.cat(input, other, _builder) @@ -508,7 +510,7 @@ def reshape(input, shape, _builder=None): Tries to reshape the given block to a new shape. :param input: The input block. - :type input: + :type input: :param shape: The desired shape. :type shape: Tuple[int] @@ -546,7 +548,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. - :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. + :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. :code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`. @@ -565,7 +567,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui @builtin def store(pointer, value, mask=None, _builder=None): """ - Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. + Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. 
@@ -600,9 +602,10 @@ def _add_atomic_docstr(name): """ func.__doc__ = docstr.format(name=name) return func - + return _decorator - + + @builtin @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): @@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None): def atomic_xchg(pointer, val, mask=None, _builder=None): return frontend.atomic_xchg(pointer, val, mask, _builder) + @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, _builder=None): @@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None): def umulhi(x, y, _builder=None): return frontend.umulhi(x, y, _builder) + def _add_math_1arg_docstr(name): def _decorator(func): @@ -694,24 +699,28 @@ def _add_math_1arg_docstr(name): """ func.__doc__ = docstr.format(name=name) return func - + return _decorator + @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): return frontend.exp(x, _builder) + @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): return frontend.log(x, _builder) + @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): return frontend.cos(x, _builder) + @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): @@ -739,9 +748,10 @@ def _add_reduction_docstr(name): """ func.__doc__ = docstr.format(name=name) return func - + return _decorator + @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): @@ -759,6 +769,7 @@ def min(input, axis, _builder=None): def sum(input, axis, _builder=None): return frontend.sum(input, axis, _builder) + @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): @@ -778,7 +789,7 @@ def debug_barrier(_builder=None): @builtin def multiple_of(input, value, _builder=None): """ - Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. + Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ return frontend.multiple_of(input, value, _builder) @@ -786,7 +797,7 @@ def multiple_of(input, value, _builder=None): @builtin def max_contiguous(input, value, _builder=None): """ - Let the compiler knows that the `value` first values in :code:`input` are contiguous. + Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ return frontend.max_contiguous(input, value, _builder) @@ -794,7 +805,7 @@ def max_contiguous(input, value, _builder=None): @builtin def max_contiguous(input, value, _builder=None): """ - Let the compiler knows that the `value` first values in :code:`input` are contiguous. + Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ return frontend.max_contiguous(input, value, _builder) @@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None): def abs(x): return where(x >= 0, x, -x) + @triton.jit def cdiv(x, div): """ @@ -871,13 +883,14 @@ def ravel(x): """ return triton.language.reshape(x, [x.type.numel]) + @triton.jit def swizzle2d(i, j, size_i, size_j, size_g): """ transformes indices of a row-major size_i*size_j matrix into those of one where indices are row major for each group of size_j rows. 
For example, for size_i = size_j = 4 and size_g = 2, it will transform - [[0 , 1 , 2 , 3 ], + [[0 , 1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 , 9 , 10, 11], [12, 13, 14, 15]] @@ -888,16 +901,16 @@ def swizzle2d(i, j, size_i, size_j, size_g): [9, 11, 13, 15]] """ # "unrolled index in array" - ij = i*size_j + j + ij = i * size_j + j # number of elements in `size_g` groups # of `size_j` columns size_gj = size_g * size_j # index of the group in which (i,j) is - group_id = ij // size_gj + group_id = ij // size_gj # row-index of the first element of this group off_i = group_id * size_g # last group may have fewer rows - size_g = minimum(size_i - off_i, size_g) + size_g = minimum(size_i - off_i, size_g) # new row and column indices new_i = off_i + (ij % size_g) new_j = (ij % size_gj) // size_g diff --git a/python/triton/language/random.py b/python/triton/language/random.py index cb2ddfc6b..6f3645b41 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -1,7 +1,6 @@ import triton from . import core as tl - PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py index ca6ca61f8..7d27ffd20 100644 --- a/python/triton/ops/__init__.py +++ b/python/triton/ops/__init__.py @@ -1,4 +1,4 @@ #from .conv import _conv, conv -from .matmul import _matmul, matmul +from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy -from . import blocksparse \ No newline at end of file +from .matmul import _matmul, matmul diff --git a/python/triton/ops/blocksparse/__init__.py b/python/triton/ops/blocksparse/__init__.py index c8da856aa..231c27a1f 100644 --- a/python/triton/ops/blocksparse/__init__.py +++ b/python/triton/ops/blocksparse/__init__.py @@ -1,2 +1,2 @@ from .matmul import matmul -from .softmax import softmax \ No newline at end of file +from .softmax import softmax diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index ce15c9af4..15e6c0523 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -1,6 +1,7 @@ +import torch + import triton import triton.language as tl -import torch # ******************************************************** # -------------------------------------------------------- @@ -11,16 +12,17 @@ import torch # -------------------------------------------------------- # ******************************************************** + @triton.heuristics({ 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, }) @triton.jit def _sdd_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ak, - stride_zb, stride_hb, stride_bk, stride_nb, - stride_zc, stride_hc, stride_mc, stride_nc, - K, grid_offset, lut, + A, B, C, + stride_za, stride_ha, stride_ma, stride_ak, + stride_zb, stride_hb, stride_bk, stride_nb, + stride_zc, stride_hc, stride_mc, stride_nc, + K, grid_offset, lut, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, BLOCK: tl.constexpr, EVEN_K: tl.constexpr ): @@ -30,25 +32,25 @@ def _sdd_kernel( block_id = tl.program_id(1) + grid_offset lut += block_id * 3 # offsets - off_z = tl.program_id(2) # batch - off_h = tl.load(lut + 0) # head - + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + # initialize pointers to A start_am = tl.load(lut + 1) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_ak = tl.arange(0, TILE_K) - 
a_ptrs = A + (off_z * stride_za \ - + off_h * stride_ha \ - + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ak) + a_ptrs = A + (off_z * stride_za + + off_h * stride_ha + + offs_am[:, None] * stride_ma + + offs_ak[None, :] * stride_ak) # initialize pointers to B start_bn = tl.load(lut + 2) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bk = tl.arange(0, TILE_K) - b_ptrs = B + (off_z * stride_zb \ - + off_h * stride_hb \ - + offs_bn[None, :] * stride_nb \ - + offs_bk[:, None] * stride_bk) + b_ptrs = B + (off_z * stride_zb + + off_h * stride_hb + + offs_bn[None, :] * stride_nb + + offs_bk[:, None] * stride_bk) ## ---------------- ## ## Inner Loop ## ## ---------------- ## @@ -69,13 +71,14 @@ def _sdd_kernel( ## ---------------- ## offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK - pc = C + (off_z * stride_zc \ - + block_id * stride_hc \ - + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc) + pc = C + (off_z * stride_zc + + block_id * stride_hc + + offs_cm[:, None] * stride_mc + + offs_cn[None, :] * stride_nc) tl.store(pc, c, mask=True) -def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): if a.stride(2) != 1 and a.stride(3) != 1: a = a.contiguous() if b.stride(2) != 1 and b.stride(3) != 1: @@ -103,7 +106,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(2), c.stride(3), Ka, 0, lut, - TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4, + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, num_warps=4, ) return c @@ -119,50 +122,52 @@ def sdd_lut(layout, block, device): # This operation uses a look-up table that contains pre-computed pointer increments # in order to minimize computations in the inner loop of the matmul kernel. 
# ----------------------------- + + @triton.jit def _dsd_kernel( - A, B, C, - stride_az, stride_ha, stride_am, stride_ak, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_cm, stride_cn, - DS0, DS1, lut, + A, B, C, + stride_az, stride_ha, stride_am, stride_ak, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_cm, stride_cn, + DS0, DS1, lut, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr ): #------------# #- Prologue -# #------------# - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) - pidz = tl.program_id(2) + pidz = tl.program_id(2) header = lut + pid_n * 4 offset = tl.load(header + 0) - K = tl.load(header + 1) + K = tl.load(header + 1) column = tl.load(header + 2) - off_h = tl.load(header + 3) - pinc = lut + offset + off_h = tl.load(header + 3) + pinc = lut + offset # initialize pointers to A (sparse) - block_id = tl.load(pinc + 1) - block_id = tl.multiple_of(block_id, 8) # compiler hint + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint offs_am = tl.arange(0, TILE_M) offs_ak = tl.arange(0, TILE_K) pa = A + pidz * stride_az \ - + block_id * stride_ha \ - + offs_am[:, None] * stride_am \ - + offs_ak[None, :] * stride_ak + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak # initialize pointers to B (dense) - offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) start_bk = tl.load(pinc) start_bk = tl.multiple_of(start_bk, 8) # compiler hint - offs_bk = start_bk + tl.arange(0, TILE_K) + offs_bk = start_bk + tl.arange(0, TILE_K) pb = B + pidz * stride_zb \ - + off_h * stride_hb \ - + offs_bn[None, :] * stride_bn \ - + offs_bk[:, None] * stride_bk + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk ## ---------------- ## ## Inner Loop ## ## ---------------- ## @@ -177,7 +182,7 @@ def _dsd_kernel( b = tl.load(pb, mask=offs_bn[None, :] < DS0) acc += tl.dot(a, b) pa += inc_a - pb += inc_b*stride_bk + pb += inc_b * stride_bk pinc += 2 inc_a = tl.load(pinc + 1) inc_a = tl.multiple_of(inc_a, 8) @@ -185,15 +190,16 @@ def _dsd_kernel( inc_b = tl.multiple_of(inc_b, 8) c = acc.to(C.dtype.element_ty) # initialize pointers to C - offs_cm = column*TILE_M + tl.arange(0, TILE_M) - offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N) + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) pc = C + off_h * stride_hc \ - + pidz * stride_zc \ - + offs_cm[:, None] * stride_cm \ - + offs_cn[None, :] * stride_cn - tl.store(pc, c, mask = offs_cn[None, :] < DS0) + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) -def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): if a.stride(2) != 1 and a.stride(3) != 1: a = a.contiguous() if b.stride(2) != 1 and b.stride(3) != 1: @@ -231,7 +237,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = # exit() return c -def dsd_lut(layout, block, step, trans, 
device): +def dsd_lut(layout, block, step, trans, device): sizes = torch.sum(layout, 2 if trans else 1) head_id, col_id = sizes.nonzero(as_tuple=True) sizes = sizes.flatten() @@ -313,11 +319,11 @@ def dsd_lut(layout, block, step, trans, device): # ----------------------------- @triton.jit def _dds_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ka, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_mc, stride_nc, - DS0, DS1, lut, + A, B, C, + stride_za, stride_ha, stride_ma, stride_ka, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_mc, stride_nc, + DS0, DS1, lut, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, ): @@ -348,7 +354,7 @@ def _dds_kernel( + offs_ak[None, :] * stride_ka # initialize pointers to B (sparse) block_id = tl.load(pinc + 1) - block_id = tl.multiple_of(block_id, 8) + block_id = tl.multiple_of(block_id, 8) offs_bn = tl.arange(0, TILE_N) offs_bk = tl.arange(0, TILE_K) ptrs_b = B + pid_z * stride_zb \ @@ -429,7 +435,7 @@ class _matmul(torch.autograd.Function): @staticmethod def forward( - ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, + ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, db_width, out ): c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) @@ -499,10 +505,10 @@ class matmul: def __call__(self, a, b, out = None): c = _matmul.apply( - a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, - self.c_lut, self.c_width, - self.da_lut, self.da_width, + a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, + self.c_lut, self.c_width, + self.da_lut, self.da_width, self.db_lut, self.db_width, out ) - return c \ No newline at end of file + return c diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index dcf77afc8..f9d49ae56 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -1,7 +1,8 @@ -import triton.language as tl -import triton import torch +import triton +import triton.language as tl + def num_warps(n): if n < 512: @@ -33,10 +34,10 @@ def _forward( check = rbn < size rbmn = tl.where(check, rbn, size - 1) # block id and column id - blockid = tl.load(LUT + offset + rbmn * 4 + 0) + blockid = tl.load(LUT + offset + rbmn * 4 + 0) columnid = tl.load(LUT + offset + rbmn * 4 + 1) - rowid = tl.load(LUT + offset + rbmn * 4 + 2) - headid = tl.load(LUT + offset + rbmn * 4 + 3) + rowid = tl.load(LUT + offset + rbmn * 4 + 2) + headid = tl.load(LUT + offset + rbmn * 4 + 3) # pointers to X px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn x = tl.load(px, mask=check, other=-float('inf')) @@ -64,7 +65,7 @@ def _forward( attn_m = tl.where(attn_m == 0, -float('inf'), 0.) x = x + attn_m # apply causal mask - is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm + is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) 
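The causal mask applied just above reconstructs global row/column positions from the block ids stored in the LUT; a small sketch of the same index arithmetic on plain tensors (BLOCK, rowid, columnid and rxm are illustrative constants here, not the values the kernel actually loads):

    import torch

    BLOCK = 16
    rxm = 3                          # intra-block row handled by this program
    rxn = torch.arange(BLOCK)        # intra-block columns
    rowid, columnid = 2, 2           # a diagonal block: only entries right of the diagonal get masked
    # an element is masked out when its global column index exceeds its global row index
    is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
    x = torch.randn(BLOCK)
    x = torch.where(is_in_upper_triangle, torch.full_like(x, float('-inf')), x)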
# computation x = tl.softmax(x) @@ -127,9 +128,9 @@ class _softmax(torch.autograd.Function): @staticmethod def forward( - ctx, x, scale, rpe, - key_padding_mask, attn_mask, - kp_mask_mode, attn_mask_mode, + ctx, x, scale, rpe, + key_padding_mask, attn_mask, + kp_mask_mode, attn_mask_mode, is_causal, spdims, block, lut, maxlut ): @@ -161,15 +162,15 @@ class _softmax(torch.autograd.Function): # run kernel M = x.shape[0] grid = [spdims[0] * spdims[1] * block, M] - _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\ - stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - BLOCK = block, - APPLY_SCALE = apply_scale, - APPLY_RPE = apply_rpe, - APPLY_KP_MASK = apply_kp_mask, - APPLY_ATTN_MASK = apply_attn_mask, - KP_MASK_MUL = (kp_mask_mode == 'mul'), - ATTN_MASK_MUL = (attn_mask_mode == 'mul')) + _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0), + stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, + BLOCK=block, + APPLY_SCALE=apply_scale, + APPLY_RPE=apply_rpe, + APPLY_KP_MASK=apply_kp_mask, + APPLY_ATTN_MASK=apply_attn_mask, + KP_MASK_MUL=(kp_mask_mode == 'mul'), + ATTN_MASK_MUL=(attn_mask_mode == 'mul')) # save to context ctx.mark_dirty(x) ctx.save_for_backward(x, lut) @@ -211,10 +212,10 @@ class softmax: self.lut_cache = dict() def __call__( - self, x, scale=1., rpe=None, - key_padding_mask=None, attn_mask=None, + self, x, scale=1., rpe=None, + key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add', - is_causal = False + is_causal=False ): if rpe is not None and rpe.dtype != x.dtype: raise ValueError('relative position embedding must be %s' % x.dtype) @@ -224,11 +225,11 @@ class softmax: raise ValueError('Key padding mask must be %s' % x.dtype) lut, maxlut = self.make_lut(x.device) x = _softmax.apply( - x, scale, rpe, - key_padding_mask, attn_mask, - key_padding_mask_mode, attn_mask_mode, + x, scale, rpe, + key_padding_mask, attn_mask, + key_padding_mask_mode, attn_mask_mode, is_causal, - self.spdims, self.block, + self.spdims, self.block, lut, maxlut ) - return x \ No newline at end of file + return x diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 529b6c675..dfd4f4487 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -1,7 +1,9 @@ import os + +import torch + import triton import triton.language as tl -import torch def next_power_of_2(n): @@ -104,4 +106,4 @@ class _cross_entropy(torch.autograd.Function): return neg_logprobs, None -cross_entropy = _cross_entropy.apply \ No newline at end of file +cross_entropy = _cross_entropy.apply diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 8b7299a8b..60ecc9f3b 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -1,11 +1,14 @@ import torch -import triton.language as tl + import triton +import triton.language as tl from .matmul_perf_model import * + def init_to_zero(name): return lambda nargs: nargs[name].zero_() + def get_configs_io_bound(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -14,14 +17,15 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + 
num_stages=num_stages, num_warps=num_warps)) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) return configs + @triton.heuristics({ 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, }) @@ -30,26 +34,26 @@ def get_configs_io_bound(): # basic configs for compute-bound matmuls triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), ] + get_configs_io_bound(), key=['M', 'N', 'K'], prune_configs_by={ - 'prune_num_stages_by' : prune_num_stages, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + 'prune_num_stages_by': prune_num_stages, + 'perf_model': estimate_matmul_time, + 'top_k': 10 }, ) @triton.jit -def _kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, +def _kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): # matrix multiplication @@ -68,12 +72,12 @@ def _kernel(A, B, C, M, N, K, rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) # pointers A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(K, 0, -BLOCK_K*SPLIT_K): + for k in range(K, 0, -BLOCK_K * SPLIT_K): if EVEN_K: a = tl.load(A) b = tl.load(B) @@ -117,10 
+121,10 @@ class _matmul(torch.autograd.Function): c = torch.empty((M, N), device=device, dtype=a.dtype) # launch kernel grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), + _kernel[grid](a, b, c, M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), GROUP_M=8) return c diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index 16667a7b1..af4f3eed8 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -1,116 +1,121 @@ +import heapq + import torch + import triton import triton._C.libtriton.triton as _triton from triton.testing import get_dram_gbps, get_max_tensorcore_tflops -import heapq + def get_tensorcore_tflops(backend, device, num_ctas, num_warps): - ''' return compute throughput in TOPS ''' - total_warps = num_ctas * min(num_warps, 4) - num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs - tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device) - return tflops + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device) + return tflops + def estimate_matmul_time( - # backend, device, - num_warps, num_stages, - M, N, K, - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, - debug=False, **kwargs + # backend, device, + num_warps, num_stages, + M, N, K, + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, + debug=False, **kwargs ): - ''' return estimated running time in ms - = max(compute, loading) + store ''' - backend = _triton.runtime.backend.CUDA - device = torch.cuda.current_device() + ''' return estimated running time in ms + = max(compute, loading) + store ''' + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() - num_cta_m = triton.cdiv(M, BLOCK_M) - num_cta_n = triton.cdiv(N, BLOCK_N) - num_cta_k = SPLIT_K - num_ctas = num_cta_m * num_cta_n * num_cta_k + num_cta_m = triton.cdiv(M, BLOCK_M) + num_cta_n = triton.cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k - # If the input is smaller than the block size - M, N = max(M, BLOCK_M), max(N, BLOCK_N) + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) - # time to compute - total_ops = 2*M*N*K / (1024*1024*1024) # GOPS - tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) - compute_ms = total_ops / tput + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) + compute_ms = total_ops / tput - # time to load data - num_sm = _triton.runtime.num_sm(backend, device) - active_cta_ratio = min(1, num_ctas/num_sm) - active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate - active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5% - dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s - l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
- # assume 80% of (following) loads are in L2 cache - load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2) - load_a_l2 = M*K*2*0.8*(num_cta_n-1) - load_b_dram = N*K*2*(1+0.2*(num_cta_m-1)) - load_b_l2 = N*K*2*0.8*(num_cta_m-1) - # total - total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB - total_l2 = (load_a_l2 + load_b_l2) / (1024*1024) - # loading time in ms - load_ms = total_dram/dram_bw + total_l2/l2_bw + # time to load data + num_sm = _triton.runtime.num_sm(backend, device) + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2) + load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw - # estimate storing time - store_bw = dram_bw * 0.6 # :o - store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB - if SPLIT_K == 1: - store_ms = store_c_dram /store_bw - else: - reduce_bw = store_bw - store_ms = store_c_dram/reduce_bw - # c.zero_() - zero_ms = M*N*2/(1024*1024)/store_bw - store_ms += zero_ms + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms - total_time_ms = max(compute_ms, load_ms) + store_ms - if debug: - print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' - f'loading time: {load_ms}ms, store time: {store_ms}ms, ' - f'Activate CTAs: {active_cta_ratio*100}%') - return total_time_ms def prune_num_stages(configs): - backend = _triton.runtime.backend.CUDA - device = torch.cuda.current_device() - cc = _triton.runtime.cc(backend, device) - # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() + cc = _triton.runtime.cc(backend, device) + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages - # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) - configs_map = {} - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ - kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages - - key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) - if key in configs_map: - configs_map[key].append((config, num_stages)) - else: - configs_map[key] = [(config, num_stages)] + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: 
+ kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages - pruned_configs = [] - for k, v in configs_map.items(): - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k - if cc >= 80: - # compute cycles (only works for ampere GPUs) - mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16) - mma_cycles = mmas/min(4, num_warps) * 8 + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] - ldgsts_latency = 300 # Does this matter? - optimal_num_stages = ldgsts_latency/mma_cycles + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if cc >= 80: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 - # nearest stages, prefer large #stages - nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \ - if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles - for n in nearest: - pruned_configs.append(n[0]) - else: # Volta & Turing only supports num_stages <= 2 - random_config = v[0][0] - random_config.num_stages = 2 - pruned_configs.append(random_config) - return pruned_configs \ No newline at end of file + # nearest stages, prefer large #stages + nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/python/triton/testing.py b/python/triton/testing.py index eef7f5be6..310e754ed 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,10 +1,11 @@ -import torch import os -import triton._C.libtriton.triton as _triton -from .code_gen import OutOfResources import subprocess import sys +import torch + +import triton._C.libtriton.triton as _triton +from .code_gen import OutOfResources try: import triton._C.libtriton.cutlass as _cutlass @@ -13,6 +14,7 @@ except ImportError: _cutlass = None has_cutlass = False + def catch_oor(kernel, pytest_handle=None): try: res = kernel() @@ -42,11 +44,11 @@ def cutlass_matmul(a, b): c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device) # run function dtype = str(a.dtype).split('.')[-1] - _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ - M, N, Ka,\ - a.stride(0), a.stride(1),\ - b.stride(0), b.stride(1),\ - c.stride(0), c.stride(1),\ + _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), + M, N, Ka, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), dtype, dtype, dtype, a.device.index, torch.cuda.current_stream(a.device).cuda_stream) @@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0): ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value return ret + def assert_almost_equal(x, y, decimal=2, err_msg=''): import numpy.testing as npt if isinstance(x, torch.Tensor): @@ -93,6 +96,7 @@ def nvsmi(attrs): ret = [int(x) for x in ret] return ret + def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], 
record_clocks=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with @@ -122,13 +126,13 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0 torch.cuda.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 # compute number of warmup and repeat - n_warmup = max(1, int(warmup/estimate_ms)) - n_repeat = max(1, int(rep/estimate_ms)) + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) # We maintain a buffer of 256 MB that we clear # before each kernel call to make sure that the L2 # doesn't contain any input data before the run start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') # Warm-up for _ in range(n_warmup): @@ -161,6 +165,7 @@ class Benchmark: """ This class is used by the :code:`perf_report` function to generate line plots with a concise API. """ + def __init__( self, x_names, @@ -224,9 +229,10 @@ class Mark: self.benchmarks = benchmarks def _run(self, bench, save_path, show_plots, print_data): + import os + import matplotlib.pyplot as plt import pandas as pd - import os y_mean = bench.line_names y_min = [f'{x}-min' for x in bench.line_names] y_max = [f'{x}-max' for x in bench.line_names] @@ -259,7 +265,7 @@ class Mark: xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names) ax.set_xlabel(xlabel) ax.set_ylabel(bench.ylabel) - #ax.set_title(bench.plot_name) + # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") ax.set_yscale("log" if bench.y_log else "linear") if show_plots: @@ -297,6 +303,7 @@ def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) return wrapper + def get_dram_gbps(backend=None, device=None): ''' return DRAM bandwidth in GB/s ''' # assert backend == CUDA @@ -306,17 +313,18 @@ def get_dram_gbps(backend=None, device=None): device = torch.cuda.current_device() mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) bus_width = _triton.runtime.global_memory_bus_width(backend, device) - bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s + bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s return bw_gbps + def get_max_tensorcore_tflops(backend, device): - num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs - clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz # assume fp32 += fp16*fp16 cc = _triton.runtime.cc(backend, device) if cc < 80: - ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores else: ops_per_sub_core = 512 - tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) + tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024) return tflops diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index fbbfa6d0b..3b443c690 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -21,8 +21,8 @@ # SOFTWARE. 
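The get_dram_gbps and get_max_tensorcore_tflops helpers above are the peak-bandwidth and peak-throughput queries that matmul_perf_model.py feeds into its load and compute estimates; a minimal sketch of calling them directly (assumes a CUDA device and this version of the bindings):

    import torch
    import triton._C.libtriton.triton as _triton
    from triton.testing import get_dram_gbps, get_max_tensorcore_tflops

    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    print('DRAM bandwidth (GB/s):', get_dram_gbps(backend, device))
    print('peak tensor-core TFLOPS:', get_max_tensorcore_tflops(backend, device))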
import argparse -import subprocess import re +import subprocess FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 0934c8ea1..c78ccabbc 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn # Compute Kernel # -------------------------- -from triton.language.core import constexpr import torch + import triton import triton.language as tl @@ -38,7 +38,7 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) # Create a mask to guard memory operations against out-of-bounds accesses mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case + # Load x and y from DRAM, masking out any extra elements in case # the input is not a multiple of the block size x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 2c0cfb9a8..30e507b0d 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -16,6 +16,8 @@ You will learn about: # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Let us consider instead the case of a simple (numerically stabilized) softmax operation: +import triton.language as tl +import triton import torch @@ -59,13 +61,10 @@ def naive_softmax(x): # power-of-two number of elements, so we need to internally "pad" each row and guard the # memory operations properly if we want to handle any possible input shapes: -import triton -import triton.language as tl - @triton.jit def softmax_kernel( - output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, + output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr ): # The rows of the softmax are independent, so we parallelize across those @@ -136,7 +135,7 @@ y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) print(torch.allclose(y_triton, y_torch)) -#%% +# %% # As expected, the results are identical. # %% @@ -187,5 +186,5 @@ benchmark.run(show_plots=True, print_data=True) # In the above plot, we can see that: # # - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. -# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. +# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. # Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape. 
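The vector-add and fused-softmax tutorials above (and the matrix-multiplication tutorial below) all time their kernels with triton.testing.do_bench, whose warm-up/repeat logic is touched earlier in this patch; a minimal standalone usage sketch with arbitrary operand shapes:

    import torch
    import triton

    a = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
    b = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
    # with the default percentiles [0.5, 0.2, 0.8] this returns the median, 20th and 80th percentile times in ms
    mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=25, rep=100)
    print(f'{2 * 4096**3 * 1e-12 / (mean_ms * 1e-3):.1f} TFLOPS (median)')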
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 2d2ab91e9..240583df2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -112,13 +112,13 @@ You will specifically learn about: # # number of programs ids along the N axis # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # # number of programs in group -# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# num_pid_in_group = GROUP_SIZE_M * num_pid_n # # id of the group this program is in -# group_id = pid // num_pid_in_group +# group_id = pid // num_pid_in_group # # row-id of the first program in the group -# first_pid_m = group_id * GROUP_SIZE_M +# first_pid_m = group_id * GROUP_SIZE_M # # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller -# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # # *within groups*, programs are ordered in a column-major order # # row-id of the program in the *launch grid* # pid_m = first_pid_m + (pid % group_size_m) @@ -141,6 +141,7 @@ You will specifically learn about: # import torch + import triton import triton.language as tl @@ -152,18 +153,19 @@ import triton.language as tl # - An autotuning *key* whose change in values will trigger evaluation of all the # provided configs + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), ], 
key=['M', 'N', 'K'], ) @@ -185,7 +187,7 @@ def matmul_kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr, - ): +): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ @@ -196,16 +198,16 @@ def matmul_kernel( pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction + # We will advance this pointer as we move in the K direction # and accumulate # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers @@ -213,8 +215,8 @@ def matmul_kernel( offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) - b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix @@ -223,8 +225,8 @@ def matmul_kernel( # `accumulator` will be converted back to fp16 after the loop accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): - # Note that for simplicity, we don't apply a mask here. - # This means that if K is not a multiple of BLOCK_SIZE_K, + # Note that for simplicity, we don't apply a mask here. + # This means that if K is not a multiple of BLOCK_SIZE_K, # this will access out-of-bounds memory and produce an # error or (worse!) incorrect results. a = tl.load(a_ptrs) @@ -236,7 +238,7 @@ def matmul_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk # you can fuse arbitrary activation functions here # while the accumulator is still in FP32 ! - if meta['ACTIVATION']: + if meta['ACTIVATION']: accumulator = meta['ACTIVATION'](accumulator) c = accumulator.to(tl.float16) diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index d988746a7..5c4f53435 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -13,7 +13,7 @@ whose state is generally composed of a bit mask tensor of the same shape as the # %% # Baseline # ------------- -# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance +# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance # of deep neural networks in low-data regime (i.e. regularization). # # It takes a vector as input and produces a vector of the same shape as output. 
Each scalar in the @@ -30,16 +30,18 @@ whose state is generally composed of a bit mask tensor of the same shape as the import tabulate import torch + import triton import triton.language as tl + @triton.jit def _dropout( - x_ptr, # pointer to the input - x_keep_ptr, # pointer to a mask of 0s and 1s - output_ptr, # pointer to the output - n_elements, # number of elements in the `x` tensor - p, # probability that an element of `x` is changed to zero + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero **meta, ): BLOCK_SIZE = meta['BLOCK_SIZE'] @@ -64,6 +66,7 @@ def dropout(x, x_keep, p): _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) return output + # Input tensor x = torch.randn(size=(10,)).cuda() # Dropout mask @@ -88,7 +91,7 @@ print(tabulate.tabulate([ # of persisting randomness across multiple invocations of the kernel. # # Pseudorandom number generation in Triton is simple! In this tutorial we will use the -# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` +# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` # values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides # other :ref:`random number generation strategies `. # @@ -97,6 +100,7 @@ print(tabulate.tabulate([ # # Let's put it all together. + @triton.jit def _seeded_dropout( x_ptr, diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index ffad17f50..82231e15c 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -4,15 +4,17 @@ Layer Normalization """ import torch -import triton.language as tl + import triton +import triton.language as tl + # Forward Pass @triton.jit def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # position of elements processed by this program - row = tl.program_id(0) + row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE) mask = cols < N # offset data pointers to start at the row of interest @@ -24,9 +26,9 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): mean = tl.sum(x, axis=0) / N # compute std xmean = tl.where(mask, x - mean, 0.) 
- var = tl.sum(xmean * xmean, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - xhat = xmean*rstd + var = tl.sum(xmean * xmean, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + xhat = xmean * rstd # write-back mean/rstd tl.store(M + row, mean) tl.store(V + row, rstd) @@ -41,16 +43,16 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): # Backward pass (DX + partial DW + partial DB) @triton.jit def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, - stride, N, eps, - **META): + stride, N, eps, + **META): GROUP_SIZE_M = META['GROUP_SIZE_M'] BLOCK_SIZE_N = META['BLOCK_SIZE_N'] # position of elements processed by this program - row = tl.program_id(0) + row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) mask = cols < N # offset data pointers to start at the row of interest - X += row * stride + X += row * stride DY += row * stride DX += row * stride # offset locks and weight/bias gradient pointer @@ -59,28 +61,28 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, # these buffers stay in the L2, which allow this kernel # to be fast lock_id = row % GROUP_SIZE_M - Lock += lock_id - Count = Lock + GROUP_SIZE_M - DW = DW + lock_id*N + cols - DB = DB + lock_id*N + cols + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + DB = DB + lock_id * N + cols # load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - w = tl.load(W + cols, mask=mask).to(tl.float32) - mean = tl.load(M + row) - rstd = tl.load(V + row) + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + mean = tl.load(M + row) + rstd = tl.load(V + row) # compute dx - xhat = (x - mean)*rstd - wdy = w * dy - xhat = tl.where(mask, xhat, 0.) - wdy = tl.where(mask, wdy , 0.) + xhat = (x - mean) * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.) + wdy = tl.where(mask, wdy, 0.) mean1 = tl.sum(xhat * wdy, axis=0) / N mean2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat*mean1 + mean2))*rstd + dx = (wdy - (xhat * mean1 + mean2)) * rstd # write-back dx tl.store(DX + cols, dx, mask=mask) # accumulate partial sums for dw/db - partial_dw = (dy*xhat).to(w.dtype) + partial_dw = (dy * xhat).to(w.dtype) partial_db = (dy).to(w.dtype) while tl.atomic_cas(Lock, 0, 1) == 1: pass @@ -97,24 +99,27 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, tl.atomic_xchg(Lock, 0) # Backward pass (total DW + total DB) + + @triton.jit def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): pid = tl.program_id(0) BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] - cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None]*N + cols[None, :] + offs = rows[:, None] * N + cols[None, :] dw += tl.load(DW + offs, mask=mask, other=0.) db += tl.load(DB + offs, mask=mask, other=0.) 
sum_dw = tl.sum(dw, axis=0) sum_db = tl.sum(db, axis=0) - tl.store(FINAL_DW + cols, sum_dw, mask=cols BLOCK_SIZE: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, - x_arg.stride(0), N, eps, + _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, + x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE - ctx.num_warps = num_warps - ctx.eps = eps + ctx.num_warps = num_warps + ctx.eps = eps return y @staticmethod @@ -154,11 +159,11 @@ class LayerNorm(torch.autograd.Function): if N <= 4096: GROUP_SIZE_M = 128 if N <= 1024: GROUP_SIZE_M = 256 # allocate output - locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda') + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) - dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) - db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) + dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) + db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) dx = torch.empty_like(dy) # enqueue kernel using forward pass heuristics # also compute partial sums for DW and DB @@ -166,14 +171,14 @@ class LayerNorm(torch.autograd.Function): M, N = x_arg.shape _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, x_arg.stride(0), N, ctx.eps, - BLOCK_SIZE_N=ctx.BLOCK_SIZE, + BLOCK_SIZE_N=ctx.BLOCK_SIZE, GROUP_SIZE_M=GROUP_SIZE_M, num_warps=ctx.num_warps) grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] # accumulate partial sums in separate kernel - _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, - BLOCK_SIZE_M = 32, - BLOCK_SIZE_N = 128) + _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, + BLOCK_SIZE_M=32, + BLOCK_SIZE_N=128) return dx, None, dw, db, None @@ -184,10 +189,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') - dy = .1*torch.randn_like(x) + weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + dy = .1 * torch.randn_like(x) x.requires_grad_(True) # forward pass y_tri = layer_norm(x, w_shape, weight, bias, eps) @@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1) triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1) + @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'], @@ -218,14 +224,14 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} ) ) -def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'): +def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = 
(x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') - dy = .1*torch.randn_like(x) + weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + dy = .1 * torch.randn_like(x) x.requires_grad_(True) # utility functions if provider == 'triton': @@ -238,14 +244,15 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu y_fwd = lambda: apex_layer_norm(x) # forward pass if mode == 'forward': - gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6 + gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500) # backward pass if mode == 'backward': - gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6 + gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6 y = y_fwd() - ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) + bench_layer_norm.run(save_path='.', print_data=True)
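A quick way to exercise the layer-norm tutorial end to end is to run its correctness check before the benchmark; a short usage sketch assuming the test_layer_norm defined above (the odd row count is just an arbitrary non-power-of-two choice):

    # compares the Triton forward/backward results against the PyTorch reference
    test_layer_norm(1151, 8192, torch.float16)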