[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
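# autopep8 leaves E501 (line too long), E701 (multiple statements on one line),
# and E731 (assigned lambda) alone, keeping this pass to whitespace and import order.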
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
Author: Madeleine Thompson
Date: 2022-01-06 14:34:17 -08:00
Committed by: GitHub
Parent: 120cda015e
Commit: 8bf551ae7a
30 changed files with 742 additions and 623 deletions
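The `.isort.cfg` referenced in the message lands in a separate change and is not part of this diff. As a rough sketch only, a config in that spirit could look like the following; the values shown are assumptions for illustration, not the project's actual settings:

```
# Hypothetical .isort.cfg -- illustrative only; the real configuration ships in a separate change.
[settings]
line_length = 88
known_first_party = triton
```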

View File

@@ -1,4 +1,5 @@
import torch
import triton
# -------------------------------
@@ -8,18 +9,18 @@ import triton
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
-x_names = ['M', 'N', 'K'],
+x_names=['M', 'N', 'K'],
-x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
+x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-line_arg = 'block',
+line_arg='block',
-line_vals = [16, 32, 64, 128],
+line_vals=[16, 32, 64, 128],
-line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
+line_names=['Block16', 'Block32', 'Block64', 'Block128'],
-ylabel = 'TFLOPS',
+ylabel='TFLOPS',
-plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
+plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
-args = {'layout_mode': layout_mode, 'op_mode': op_mode,
+args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
-)\
+)
-for AT in [False] for BT in [False] \
+for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense']
]
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1
make_layout = {
-'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
+'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout
@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
-'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
+'sdd': 2 * Z * K * float(layout.sum()) * block * block,
-'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
+'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block
-}[op_mode]*1e-12
+}[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
square_confs = [
triton.testing.Benchmark(
-x_names = ['M', 'N'],
+x_names=['M', 'N'],
-x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
+x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-line_arg = 'block',
+line_arg='block',
-line_vals = [16, 32, 64],
+line_vals=[16, 32, 64],
-line_names = ['Block16', 'Block32', 'Block64'],
+line_names=['Block16', 'Block32', 'Block64'],
-ylabel = 'GBPS',
+ylabel='GBPS',
-plot_name = f'{layout_mode}-square',
+plot_name=f'{layout_mode}-square',
-args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
+args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
-)\
+)
for layout_mode in ['dense', 'tril']
]
@@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
bench_matmul.run(print_data=True, show_plots=True)

View File

@@ -1,17 +1,18 @@
import torch
import triton
confs = [
triton.testing.Benchmark(
-x_names = ['N'],
+x_names=['N'],
-x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
+x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
-line_arg = 'provider',
+line_arg='provider',
-line_vals = ['triton', 'torch'],
+line_vals=['triton', 'torch'],
-line_names = ['Triton', 'Torch'],
+line_names=['Triton', 'Torch'],
-ylabel = 'GBPS',
+ylabel='GBPS',
-plot_name = f'{mode}-2048',
+plot_name=f'{mode}-2048',
-args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
+args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
-)\
+)
for mode in ['forward', 'backward']
]
@@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
-op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
+op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':
@@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider):
if __name__ == '__main__':
bench_op.run(print_data=True)

View File

@@ -1,6 +1,6 @@
-import triton
import torch
+import os
+import triton
def rounded_linspace(low, high, steps, div):
@@ -29,16 +29,16 @@ square_confs = [
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
-x_vals = rounded_linspace(NK//16, NK, 32, 128),
+x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
-args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
+args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
-) for NK in [12288]\
+) for NK in [12288]
-for i, x in enumerate(["N", "K"])\
+for i, x in enumerate(["N", "K"])
for M in [2048]
]
@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
-if AT: a = a.t()
-if BT: b = b.t()
+if AT:
+a = a.t()
+if BT:
+b = b.t()
num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
-except:
+except Exception:
return None
return None

View File

@@ -1,7 +1,8 @@
import argparse
-import sys
-import os
import inspect
+import os
+import sys
import triton

View File

@@ -1,29 +1,28 @@
-import os
-import re
-import sys
-import sysconfig
-import platform
-import subprocess
import distutils
-import glob
-import tempfile
-import shutil
-from distutils.version import LooseVersion
-from setuptools import setup, Extension, find_packages
-from setuptools.command.build_ext import build_ext
-from setuptools.command.test import test as TestCommand
import distutils.spawn
-import urllib.request
+import os
+import platform
+import re
+import shutil
+import subprocess
+import sys
import tarfile
+import tempfile
+import urllib.request
+from distutils.version import LooseVersion
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext
def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'
@@ -32,7 +31,7 @@ def get_llvm():
if not os.path.exists(llvm_library_dir):
try:
shutil.rmtree(os.path.join(dir, name))
-except:
+except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
@@ -96,7 +95,7 @@ class CMakeBuild(build_ext):
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
#'-DPYTHON_EXECUTABLE=' + sys.executable,
-#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
+# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
]

View File

@@ -1,14 +1,18 @@
-from numpy import record
-import torch
-import triton
import subprocess
import sys
import pytest
+import torch
+from numpy import record
+import triton
+import triton.language as tl
#######################
# Utilities
#######################
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -23,48 +27,51 @@ def nvsmi(attrs):
#######################
matmul_data = {
# square
-(256 , 256 , 256 ) : {'v100': 0.027},
+(256, 256, 256): {'v100': 0.027},
-(512 , 512 , 512 ) : {'v100': 0.158},
+(512, 512, 512): {'v100': 0.158},
-(1024, 1024, 1024 ) : {'v100': 0.466},
+(1024, 1024, 1024): {'v100': 0.466},
-(2048, 2048, 2048 ) : {'v100': 0.680},
+(2048, 2048, 2048): {'v100': 0.680},
-(4096, 4096, 4096 ) : {'v100': 0.831},
+(4096, 4096, 4096): {'v100': 0.831},
-(8192, 8192, 8192 ) : {'v100': 0.849},
+(8192, 8192, 8192): {'v100': 0.849},
# tall-skinny
-(16 , 1024, 1024 ) : {'v100': 0.0128},
+(16, 1024, 1024): {'v100': 0.0128},
-(16 , 4096, 4096 ) : {'v100': 0.0883},
+(16, 4096, 4096): {'v100': 0.0883},
-(16 , 8192, 8192 ) : {'v100': 0.101},
+(16, 8192, 8192): {'v100': 0.101},
-(64 , 1024, 1024 ) : {'v100': 0.073},
+(64, 1024, 1024): {'v100': 0.073},
-(64 , 4096, 4096 ) : {'v100': 0.270},
+(64, 4096, 4096): {'v100': 0.270},
-(64 , 8192, 8192 ) : {'v100': 0.360},
+(64, 8192, 8192): {'v100': 0.360},
-(1024, 64 , 1024 ) : {'v100': 0.0692},
+(1024, 64, 1024): {'v100': 0.0692},
-(4096, 64 , 4096 ) : {'v100': 0.264},
+(4096, 64, 4096): {'v100': 0.264},
-(8192, 64 , 8192 ) : {'v100': 0.323},
+(8192, 64, 8192): {'v100': 0.323},
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},
# (64 , 64 , 65536) : {'v100': 0.},
# (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.},
}
@pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K):
ref_gpu_util = matmul_data[(M, N, K)]['v100']
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = 1350
-max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
+max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((K, N), dtype=torch.float16, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
-cur_gpu_perf = 2.*M*N*K/ms * 1e-9
+cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
#######################
# Element-Wise
#######################
-import triton.language as tl
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -80,21 +87,22 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = {
-1024*16 : {'v100': 0.0219},
+1024 * 16: {'v100': 0.0219},
-1024*64 : {'v100': 0.0791},
+1024 * 64: {'v100': 0.0791},
-1024*256 : {'v100': 0.243},
+1024 * 256: {'v100': 0.243},
-1024*1024 : {'v100': 0.534},
+1024 * 1024: {'v100': 0.534},
-1024*4096 : {'v100': 0.796},
+1024 * 4096: {'v100': 0.796},
-1024*16384: {'v100': 0.905},
+1024 * 16384: {'v100': 0.905},
-1024*65536: {'v100': 0.939},
+1024 * 65536: {'v100': 0.939},
}
@pytest.mark.parametrize('N', elementwise_data.keys())
def test_elementwise(N):
ref_gpu_util = elementwise_data[N]['v100']
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = 877
-max_gpu_perf = 512*2*ref_mem_clock*1e-3
+max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
@@ -102,7 +110,6 @@ def test_elementwise(N):
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
-cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
+cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
@@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
@@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
-@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
-# real
-[
-(dtype_x, dtype_y, op, 'real', 'real') \
-for op in ops \
-for dtype_x in dtypes \
-for dtype_y in dtypes
-] + \
-# NaNs
-[('float32', 'float32', op, mode_x, mode_y) \
-for op in ops
-for mode_x, mode_y in [('nan' , 'real'),
-('real', 'nan'),
-('nan' , 'nan')]
-])
+@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
+# real
+[
+(dtype_x, dtype_y, op, 'real', 'real')
+for op in ops
+for dtype_x in dtypes
+for dtype_y in dtypes
+] +
+# NaNs
+[('float32', 'float32', op, mode_x, mode_y)
+for op in ops
+for mode_x, mode_y in [('nan', 'real'),
+('real', 'nan'),
+('nan', 'nan')]
+])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
@@ -368,8 +376,8 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
@triton.jit
def fn(a, b):
return a + b, \
a - b, \
a * b
def test_tuples():
@@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False)
for dtype_x in dtypes
for dtype_z in dtypes
] + [
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
@@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
-z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
+z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z)
# triton result
@@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype_str, shape",
-[(dtype, shape) \
+[(dtype, shape)
-for dtype in dtypes\
+for dtype in dtypes
for shape in [128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'):
# triton kernel
@@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
-x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
+x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = tl.sum(x, axis=AXIS)
tl.store(Z + range_m, z)
# input
@@ -608,11 +618,13 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype_str, shape, perm",
-[(dtype, shape, perm) \
+[(dtype, shape, perm)
-for dtype in ['float32']\
+for dtype in ['float32']
-for shape in [(128, 128)]\
+for shape in [(128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
# triton kernel
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
# triton kernel
@@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
-ADD_MATRIX = epilogue=='add-matrix',
+ADD_MATRIX=epilogue == 'add-matrix',
-ADD_ROWS = epilogue=='add-rows',
+ADD_ROWS=epilogue == 'add-rows',
-ADD_COLS = epilogue=='add-cols')
+ADD_COLS=epilogue == 'add-cols')
# torch result
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
-z_ref += z[:,0][:, None]
+z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
-z_ref += z[0,:][None, :]
+z_ref += z[0, :][None, :]
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
@triton.jit
def kernel(out):
@@ -713,28 +727,30 @@ def test_dot_without_load():
b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b)
-pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
+pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)
-out = torch.ones((32,32), dtype=torch.float32, device="cuda")
+out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
# ---------------
# test arange
# ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr):
off = tl.arange(0, BLOCK)
val = tl.arange(START, END)
tl.store(z + off, val)
-_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
+_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
-z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
+z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
# ---------------
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
M = 32
@@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
N_offsets = tl.arange(0, N)
K_offsets = tl.arange(0, K)
-in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
+in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
-in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
+in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
o = tl.dot(x, w)
# Store output
-output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
+output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],
in2.stride()[0],
out.stride()[0],
in1.numel(),
in2.numel(),
out.numel(),
M=M, N=N, K=K)
-reference_out =torch.matmul(in1, in2)
+reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache):
src = torch.empty(128, device='cuda')
@@ -796,8 +815,8 @@ def test_load_cache_modifier(cache):
@triton.jit
def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128)
-x = tl.load(src+offsets, cache_modifier=CACHE)
+x = tl.load(src + offsets, cache_modifier=CACHE)
-tl.store(dst+offsets, x)
+tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
@@ -830,11 +849,14 @@ def test_load_cache_modifier(cache):
# ---------------
# test default
# ---------------
-#TODO: can't be local to test_default
+# TODO: can't be local to test_default
@triton.jit
-def _impl(value = 10):
+def _impl(value=10):
return value
def test_default():
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -851,7 +873,9 @@ def test_default():
# ---------------
# test noop
-#----------------
+# ----------------
def test_noop(device='cuda'):
@triton.jit
def kernel(x):
@@ -861,9 +885,9 @@ def test_noop(device='cuda'):
@pytest.mark.parametrize("value, value_type", [
-(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
+(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
-(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
+(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
-(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
+(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

View File

@@ -1,16 +1,17 @@
-import torch
-import triton
-import triton.language as tl
+import numpy as np
import pytest
import scipy.stats
-import numpy as np
+import torch
from numpy.random import Philox
+import triton
+import triton.language as tl
#####################################
-## Reference Philox Implementation
+# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
#####################################
-## Unit Tests
+# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
-[(size, seed) for size in ['10', '4,53', '10000']\
+[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'):
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
-[(size, seed) for size in [1000000]\
+[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
@@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
-[(size, seed) for size in [1000000]\
+[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):

View File

@@ -1,6 +1,7 @@
-import torch
-import triton
import pytest
+import torch
+import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx
-if is_causal: at_mask = torch.tril(at_mask)
+if is_causal:
+at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :]

View File

@@ -1,14 +1,16 @@
-import torch
-import triton
import pytest
+import torch
+import triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
-for dtype in ['float16', 'float32']\
+for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
@@ -30,4 +32,4 @@ def test_op(M, N, dtype, mode):
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)

View File

@@ -1,8 +1,10 @@
-import pytest
import itertools
-import triton
+import pytest
import torch
+import triton
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
@@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
-a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
+a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
-b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
+b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
-tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
+tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)

View File

@@ -1,13 +1,16 @@
-import torch
-import triton
-from triton.code_gen import JITFunction
-import triton.language as tl
import os
import shutil
import pytest
+import torch
+import triton
+import triton.language as tl
+from triton.code_gen import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1
@@ -73,11 +85,12 @@ def test_reuse():
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1

View File

@@ -1,9 +1,11 @@
-import torch
-import triton
-import pytest
import subprocess
-import triton.language as tl
import numpy as np
+import pytest
+import torch
+import triton
+import triton.language as tl
def get_p2p_matrix():

View File

@@ -1,26 +1,26 @@
import ast
import builtins
+import dbm
import functools
-import inspect
-import struct
-import sys
-import textwrap
import hashlib
+import inspect
import os
import pickle
+import struct
import subprocess
-import os
+import sys
+import tempfile
+import textwrap
+import time
import warnings
-from .tools.disasm import extract
+from typing import Dict, Optional
import torch
+from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
-from filelock import FileLock
-import dbm
-import tempfile
-from typing import Optional, Dict
-import time
+from .tools.disasm import extract
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
@@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor):
arg_names, kwarg_names = self.visit(node.args) arg_names, kwarg_names = self.visit(node.args)
# initialize defaults # initialize defaults
for i, default_value in enumerate(node.args.defaults): for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i-1] arg_node = node.args.args[-i - 1]
annotation = arg_node.annotation annotation = arg_node.annotation
name = arg_node.arg name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store()) st_target = ast.Name(id=name, ctx=ast.Store())
@@ -134,8 +134,7 @@ class CodeGenerator(ast.NodeVisitor):
fn.args[idx].name = arg_name fn.args[idx].name = arg_name
arg_values.append(fn.args[idx]) arg_values.append(fn.args[idx])
idx += 1 idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value) self.set_value(arg_name, arg_value)
if inline: if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
# default: call visit_Assign # default: call visit_Assign
return self.visit_Assign(node) return self.visit_Assign(node)
def visit_Assign(self, node): def visit_Assign(self, node):
_names = [] _names = []
for target in node.targets: for target in node.targets:
@@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
if else_bb: if else_bb:
self.builder.set_insert_block(else_bb) self.builder.set_insert_block(else_bb)
is_terminator = self.visit_compound_statement(node.orelse) is_terminator = self.visit_compound_statement(node.orelse)
#TODO: last statement is a terminator? # TODO: last statement is a terminator?
if not is_terminator: if not is_terminator:
self.builder.br(endif_bb) self.builder.br(endif_bb)
self.module.seal_block(endif_bb) self.module.seal_block(endif_bb)
@@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ build_cond = lambda: triton.language.where(self.visit(pos_step_node),
self.visit(pos_cond_node),\ self.visit(pos_cond_node),
self.visit(neg_cond_node),\ self.visit(neg_cond_node),
_builder=self.builder) _builder=self.builder)
#cond_node = neg_cond_node #cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation # code generation
@@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction): if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws) return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core: sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws) return fn(*args, _builder=self.builder, **kws)
return fn(*args, **kws) return fn(*args, **kws)
@@ -505,10 +503,10 @@ class Binary:
class LoadedBinary: class LoadedBinary:
def __init__(self, device: int, bin: Binary): def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend, module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name, bin.name,
bin.asm, bin.asm,
bin.shared_mem, bin.shared_mem,
device) device)
self.bin = bin self.bin = bin
self.asm = bin.asm self.asm = bin.asm
@@ -520,8 +518,8 @@ class LoadedBinary:
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel, _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
grid_0, grid_1, grid_2, grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1, self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem) args, self.bin.shared_mem)
def get_sass(self, fun=None): def get_sass(self, fun=None):
@@ -632,10 +630,14 @@ class Kernel:
@staticmethod @staticmethod
def pow2_divisor(N): def pow2_divisor(N):
if N % 16 == 0: return 16 if N % 16 == 0:
if N % 8 == 0: return 8 return 16
if N % 4 == 0: return 4 if N % 8 == 0:
if N % 2 == 0: return 2 return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1 return 1
def __init__(self, fn): def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes # attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize} if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation # transforms ints whose value is one into constants for just-in-time compilation
@@ -705,7 +707,7 @@ class Kernel:
if binary is None: if binary is None:
binary = self._compile( binary = self._compile(
*wargs, device=device_idx, attributes=attributes, *wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages, num_warps=num_warps, num_stages=num_stages,
constants=constants, constants=constants,
) )
if bin_cache_path: if bin_cache_path:
@@ -766,13 +768,12 @@ class Launcher:
def __call__(self, *wargs, **kwargs): def __call__(self, *wargs, **kwargs):
return self.kernel(*wargs, **kwargs, grid=self.grid) return self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner: class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None): def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
''' '''
:param prune_configs_by: a dict of functions that are used to prune configs, fields: :param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time 'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench 'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
@@ -788,6 +789,7 @@ class Autotuner:
self.hook = lambda args: 0 self.hook = lambda args: 0
if reset_to_zero is not None: if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero] self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args): def _hook(args):
for i in self.reset_idx: for i in self.reset_idx:
args[i].zero_() args[i].zero_()
@@ -802,7 +804,7 @@ class Autotuner:
perf_model, top_k, prune_num_stages_by = None, None, None perf_model, top_k, prune_num_stages_by = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by self.prune_num_stages_by = prune_num_stages_by
def _bench(self, *args, config, **meta): def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided # check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner # as kwargs and by the autotuner
@@ -814,6 +816,7 @@ class Autotuner:
) )
# augment meta-parameters with tunable ones # augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs) current = dict(meta, **config.kwargs)
def kernel_call(): def kernel_call():
if config.pre_hook: if config.pre_hook:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
@@ -836,9 +839,9 @@ class Autotuner:
top_k = int(len(self.configs) * top_k) top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k: if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k] pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time() bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \ timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs} for config in pruned_configs}
bench_end = time.time() bench_end = time.time()
self.bench_time = bench_end - bench_start self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
ptxas_version = '' ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
#########################3 # 3
class DependenciesFinder(ast.NodeVisitor): class DependenciesFinder(ast.NodeVisitor):
@@ -888,7 +891,7 @@ class DependenciesFinder(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
return self.globals.get(node.id, None) return self.globals.get(node.id, None)
def visit_Attribute(self, node): def visit_Attribute(self, node):
lhs = self.visit(node.value) lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute): while isinstance(lhs, ast.Attribute):
@@ -917,10 +920,10 @@ class DependenciesFinder(ast.NodeVisitor):
self.ret = (self.ret + func.hash).encode("utf-8") self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest() self.ret = hashlib.md5(self.ret).hexdigest()
class JITFunction:
cache_hook = None
class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None): def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function # information of wrapped function
@@ -946,7 +949,6 @@ class JITFunction:
# forward docs # forward docs
self.__doc__ = fn.__doc__ self.__doc__ = fn.__doc__
@property @property
@functools.lru_cache() @functools.lru_cache()
def cache_key(self): def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args. function are args.
""" """
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs self.kwargs = kwargs
self.num_warps = num_warps self.num_warps = num_warps
@@ -1049,19 +1052,19 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
.. highlight:: python .. highlight:: python
.. code-block:: python .. code-block:: python
@triton.autotune(configs=[ @triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
], ],
key=['x_size'] # the two above configs will be evaluated anytime key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes # the value of x_size changes
) )
@triton.jit @triton.jit
def kernel(x_ptr, x_size, **META): def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time. :note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times. This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration. reset the value of the provided tensor to `zero` before running any configuration.
@@ -1069,7 +1072,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:type configs: list[triton.Config] :type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str] :type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields: :param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time 'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench 'top_k': number of configs to bench
'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as its input and returns pruned configs. 'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as its input and returns pruned configs.
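For reference, a minimal sketch of how the decorator documented above is typically wired up (not part of this diff; the kernel and argument names are made up, and `reset_to_zero` is assumed to take a list of argument names, as the note about in-place updates implies):

```
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],        # re-evaluate all configs whenever n_elements changes
    reset_to_zero=['y_ptr'],   # y is updated in place, so zero it between timing runs
)
@triton.jit
def accumulate_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + y, mask=mask)
```

The launch grid can read the selected meta-parameters, e.g. `grid = lambda META: (triton.cdiv(n_elements, META['BLOCK_SIZE']),)`, which is the pattern the dense matmul wrapper later in this change uses.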
@@ -1099,7 +1102,7 @@ def heuristics(values):
def kernel(x_ptr, x_size, **META): def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input. each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]] :type values: dict[str, Callable[[list[Any]], Any]]
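A small sketch of the same mechanism in the `nargs` style used by the `EVEN_K` heuristics elsewhere in this change (the kernel is hypothetical; `triton.next_power_of_2` is the helper defined above in code_gen.py):

```
import triton
import triton.language as tl


@triton.heuristics({
    # derive the block size from the runtime argument instead of autotuning it
    'BLOCK_SIZE': lambda nargs: triton.next_power_of_2(nargs['x_size']),
})
@triton.jit
def double_kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < x_size
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x * 2, mask=mask)
```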
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
def cdiv(x, y): def cdiv(x, y):
return (x + y - 1) // y return (x + y - 1) // y
def next_power_of_2(n): def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n""" """Return the smallest power of 2 greater than or equal to n"""
n -= 1 n -= 1
@@ -1163,13 +1167,14 @@ def next_power_of_2(n):
###### ######
class TensorWrapper: class TensorWrapper:
def __init__(self, base, dtype): def __init__(self, base, dtype):
self.dtype = dtype self.dtype = dtype
self.base = base self.base = base
self.is_cuda = base.is_cuda self.is_cuda = base.is_cuda
self.device = base.device self.device = base.device
def data_ptr(self): def data_ptr(self):
return self.base.data_ptr() return self.base.data_ptr()
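A hypothetical usage note for the wrapper above: it only forwards `data_ptr()`, `device` and `is_cuda` from the base tensor while overriding the element type Triton sees, so the underlying bytes are reinterpreted rather than converted. The import path `triton.code_gen` is assumed here:

```
import torch
import triton.language as tl
from triton.code_gen import TensorWrapper  # path assumed for this sketch

buf = torch.zeros(1024, dtype=torch.float16)   # a CPU tensor is enough for the sketch
wrapped = TensorWrapper(buf, tl.int8)          # same storage, viewed as int8 elements
assert wrapped.data_ptr() == buf.data_ptr()
assert wrapped.device == buf.device and wrapped.is_cuda == buf.is_cuda
```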

View File

@@ -1,8 +1,8 @@
import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
from functools import wraps from functools import wraps
import triton
from triton._C.libtriton.triton import frontend, ir
# convert block/dtype to ir values # convert block/dtype to ir values
def _to_ir(x, builder): def _to_ir(x, builder):
@@ -65,7 +65,7 @@ def builtin(fn):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \ if '_builder' not in kwargs or \
kwargs['_builder'] is None: kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs) return fn(*args, **kwargs)
return wrapper return wrapper
@@ -111,6 +111,7 @@ class pointer_dtype:
def __str__(self): def __str__(self):
return f'pointer<{self.element_ty}>' return f'pointer<{self.element_ty}>'
# scalar types # scalar types
int1 = dtype(ir.type.get_int1) int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8) int8 = dtype(ir.type.get_int8)
@@ -331,27 +332,27 @@ class constexpr:
def __rsub__(self, other): def __rsub__(self, other):
return other.value - self.value return other.value - self.value
def __mul__(self, other): def __mul__(self, other):
return self.value * other.value return self.value * other.value
def __rmul__(self, other): def __rmul__(self, other):
return other.value * self.value return other.value * self.value
def __truediv__(self, other): def __truediv__(self, other):
return self.value / other.value return self.value / other.value
def __rtruediv__(self, other): def __rtruediv__(self, other):
return other.value / self.value return other.value / self.value
def __floordiv__(self, other): def __floordiv__(self, other):
return self.value // other.value return self.value // other.value
def __rfloordiv__(self, other): def __rfloordiv__(self, other):
return other.value // self.value return other.value // self.value
# #
def __gt__(self, other): def __gt__(self, other):
return self.value > other.value return self.value > other.value
@@ -360,25 +361,25 @@ class constexpr:
def __ge__(self, other): def __ge__(self, other):
return self.value >= other.value return self.value >= other.value
def __rge__(self, other): def __rge__(self, other):
return other.value >= self.value return other.value >= self.value
def __lt__(self, other): def __lt__(self, other):
return self.value < other.value return self.value < other.value
def __rlt__(self, other): def __rlt__(self, other):
return other.value < self.value return other.value < self.value
def __le__(self, other): def __le__(self, other):
return self.value <= other.value return self.value <= other.value
def __rle__(self, other): def __rle__(self, other):
return other.value <= self.value return other.value <= self.value
def __eq__(self, other): def __eq__(self, other):
return self.value == other.value return self.value == other.value
def __ne__(self, other): def __ne__(self, other):
return self.value != other.value return self.value != other.value
@@ -489,15 +490,16 @@ def broadcast_to(input, shape, _builder=None):
""" """
return frontend.broadcast_to(input, shape, _builder) return frontend.broadcast_to(input, shape, _builder)
@builtin @builtin
def cat(input, other, _builder=None): def cat(input, other, _builder=None):
""" """
Concatenate the given blocks Concatenate the given blocks
:param input: The first input block. :param input: The first input block.
:type input: :type input:
:param other: The second input block. :param other: The second input block.
:type other: :type other:
""" """
return frontend.cat(input, other, _builder) return frontend.cat(input, other, _builder)
@@ -508,7 +510,7 @@ def reshape(input, shape, _builder=None):
Tries to reshape the given block to a new shape. Tries to reshape the given block to a new shape.
:param input: The input block. :param input: The input block.
:type input: :type input:
:param shape: The desired shape. :param shape: The desired shape.
:type shape: Tuple[int] :type shape: Tuple[int]
@@ -546,7 +548,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
""" """
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`. :code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
@@ -565,7 +567,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
@builtin @builtin
def store(pointer, value, mask=None, _builder=None): def store(pointer, value, mask=None, _builder=None):
""" """
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
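An illustrative kernel (assuming a CUDA device; not part of this diff) showing the masking and broadcast behaviour described above, with `other` supplying the value for masked-off lanes:

```
import torch
import triton
import triton.language as tl


@triton.jit
def relu_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # lanes past the end read 0.0 instead of touching out-of-bounds memory
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(y_ptr + offsets, tl.where(x > 0, x, 0.0), mask=mask)


x = torch.randn(1000, device='cuda')
y = torch.empty_like(x)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
relu_kernel[grid](x, y, x.numel(), BLOCK_SIZE=256)
```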
@@ -600,9 +602,10 @@ def _add_atomic_docstr(name):
""" """
func.__doc__ = docstr.format(name=name) func.__doc__ = docstr.format(name=name)
return func return func
return _decorator return _decorator
@builtin @builtin
@_add_atomic_docstr("compare-and-swap") @_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None): def atomic_cas(pointer, cmp, val, _builder=None):
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
def atomic_xchg(pointer, val, mask=None, _builder=None): def atomic_xchg(pointer, val, mask=None, _builder=None):
return frontend.atomic_xchg(pointer, val, mask, _builder) return frontend.atomic_xchg(pointer, val, mask, _builder)
@builtin @builtin
@_add_atomic_docstr("add") @_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None): def atomic_add(pointer, val, mask=None, _builder=None):
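A rough sketch of the read-modify-write pattern these builtins exist for (assumes a CUDA device; names are made up): several lanes and programs may update the same output element, so the accumulation has to be atomic:

```
import torch
import triton
import triton.language as tl


@triton.jit
def bucket_sum_kernel(x_ptr, out_ptr, n_elements, n_buckets, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # many lanes can hit the same bucket, hence atomic_add instead of store
    tl.atomic_add(out_ptr + offsets % n_buckets, x, mask=mask)


x = torch.randn(4096, device='cuda')
out = torch.zeros(16, device='cuda')
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
bucket_sum_kernel[grid](x, out, x.numel(), out.numel(), BLOCK_SIZE=1024)
```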
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
def umulhi(x, y, _builder=None): def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder) return frontend.umulhi(x, y, _builder)
def _add_math_1arg_docstr(name): def _add_math_1arg_docstr(name):
def _decorator(func): def _decorator(func):
@@ -694,24 +699,28 @@ def _add_math_1arg_docstr(name):
""" """
func.__doc__ = docstr.format(name=name) func.__doc__ = docstr.format(name=name)
return func return func
return _decorator return _decorator
@builtin @builtin
@_add_math_1arg_docstr("exponential") @_add_math_1arg_docstr("exponential")
def exp(x, _builder=None): def exp(x, _builder=None):
return frontend.exp(x, _builder) return frontend.exp(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("natural logarithm") @_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None): def log(x, _builder=None):
return frontend.log(x, _builder) return frontend.log(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("cosine") @_add_math_1arg_docstr("cosine")
def cos(x, _builder=None): def cos(x, _builder=None):
return frontend.cos(x, _builder) return frontend.cos(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("sine") @_add_math_1arg_docstr("sine")
def sin(x, _builder=None): def sin(x, _builder=None):
@@ -739,9 +748,10 @@ def _add_reduction_docstr(name):
""" """
func.__doc__ = docstr.format(name=name) func.__doc__ = docstr.format(name=name)
return func return func
return _decorator return _decorator
@builtin @builtin
@_add_reduction_docstr("maximum") @_add_reduction_docstr("maximum")
def max(input, axis, _builder=None): def max(input, axis, _builder=None):
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
def sum(input, axis, _builder=None): def sum(input, axis, _builder=None):
return frontend.sum(input, axis, _builder) return frontend.sum(input, axis, _builder)
@builtin @builtin
@_add_reduction_docstr("xor sum") @_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None): def xor_sum(input, axis, _builder=None):
@@ -778,7 +789,7 @@ def debug_barrier(_builder=None):
@builtin @builtin
def multiple_of(input, value, _builder=None): def multiple_of(input, value, _builder=None):
""" """
Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
""" """
return frontend.multiple_of(input, value, _builder) return frontend.multiple_of(input, value, _builder)
@@ -786,7 +797,7 @@ def multiple_of(input, value, _builder=None):
@builtin @builtin
def max_contiguous(input, value, _builder=None): def max_contiguous(input, value, _builder=None):
""" """
Let the compiler know that the first `value` values in :code:`input` are contiguous. Let the compiler know that the first `value` values in :code:`input` are contiguous.
""" """
return frontend.max_contiguous(input, value, _builder) return frontend.max_contiguous(input, value, _builder)
@@ -794,7 +805,7 @@ def max_contiguous(input, value, _builder=None):
@builtin @builtin
def max_contiguous(input, value, _builder=None): def max_contiguous(input, value, _builder=None):
""" """
Let the compiler know that the first `value` values in :code:`input` are contiguous. Let the compiler know that the first `value` values in :code:`input` are contiguous.
""" """
return frontend.max_contiguous(input, value, _builder) return frontend.max_contiguous(input, value, _builder)
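These two hints are usually combined the way the block-sparse and dense matmul kernels in this change do; a minimal stand-alone sketch (hypothetical kernel, launched so that every block is full, which is what makes the asserted alignment and contiguity actually true):

```
import triton
import triton.language as tl


@triton.jit
def aligned_copy_kernel(src_ptr, dst_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # each block starts at a multiple of BLOCK_SIZE and covers BLOCK_SIZE
    # consecutive elements; the hints let the compiler emit vectorized accesses
    offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK_SIZE), BLOCK_SIZE)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))
```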
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
def abs(x): def abs(x):
return where(x >= 0, x, -x) return where(x >= 0, x, -x)
@triton.jit @triton.jit
def cdiv(x, div): def cdiv(x, div):
""" """
@@ -871,13 +883,14 @@ def ravel(x):
""" """
return triton.language.reshape(x, [x.type.numel]) return triton.language.reshape(x, [x.type.numel])
@triton.jit @triton.jit
def swizzle2d(i, j, size_i, size_j, size_g): def swizzle2d(i, j, size_i, size_j, size_g):
""" """
transforms indices of a row-major size_i*size_j matrix into those transforms indices of a row-major size_i*size_j matrix into those
of one where indices are row major for each group of size_j rows. of one where indices are row major for each group of size_j rows.
For example, for size_i = size_j = 4 and size_g = 2, it will transform For example, for size_i = size_j = 4 and size_g = 2, it will transform
[[0 , 1 , 2 , 3 ], [[0 , 1 , 2 , 3 ],
[4 , 5 , 6 , 7 ], [4 , 5 , 6 , 7 ],
[8 , 9 , 10, 11], [8 , 9 , 10, 11],
[12, 13, 14, 15]] [12, 13, 14, 15]]
@@ -888,16 +901,16 @@ def swizzle2d(i, j, size_i, size_j, size_g):
[9, 11, 13, 15]] [9, 11, 13, 15]]
""" """
# "unrolled index in array" # "unrolled index in array"
ij = i*size_j + j ij = i * size_j + j
# number of elements in `size_g` groups # number of elements in `size_g` groups
# of `size_j` columns # of `size_j` columns
size_gj = size_g * size_j size_gj = size_g * size_j
# index of the group in which (i,j) is # index of the group in which (i,j) is
group_id = ij // size_gj group_id = ij // size_gj
# row-index of the first element of this group # row-index of the first element of this group
off_i = group_id * size_g off_i = group_id * size_g
# last group may have fewer rows # last group may have fewer rows
size_g = minimum(size_i - off_i, size_g) size_g = minimum(size_i - off_i, size_g)
# new row and column indices # new row and column indices
new_i = off_i + (ij % size_g) new_i = off_i + (ij % size_g)
new_j = (ij % size_gj) // size_g new_j = (ij % size_gj) // size_g
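The same index arithmetic in plain Python, run on the 4x4, size_g=2 example from the docstring, so the full remapped matrix can be printed and inspected:

```
def swizzle2d_py(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                   # unrolled index in the array
    size_gj = size_g * size_j             # elements in `size_g` groups of `size_j` columns
    group_id = ij // size_gj              # group containing (i, j)
    off_i = group_id * size_g             # first row of this group
    size_g = min(size_i - off_i, size_g)  # last group may have fewer rows
    return off_i + (ij % size_g), (ij % size_gj) // size_g


out = [[0] * 4 for _ in range(4)]
for i in range(4):
    for j in range(4):
        ni, nj = swizzle2d_py(i, j, 4, 4, 2)
        out[ni][nj] = i * 4 + j
print(out)  # [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
```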

View File

@@ -1,7 +1,6 @@
import triton import triton
from . import core as tl from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53

View File

@@ -1,4 +1,4 @@
#from .conv import _conv, conv #from .conv import _conv, conv
from .matmul import _matmul, matmul from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse from .matmul import _matmul, matmul

View File

@@ -1,2 +1,2 @@
from .matmul import matmul from .matmul import matmul
from .softmax import softmax from .softmax import softmax

View File

@@ -1,6 +1,7 @@
import torch
import triton import triton
import triton.language as tl import triton.language as tl
import torch
# ******************************************************** # ********************************************************
# -------------------------------------------------------- # --------------------------------------------------------
@@ -11,16 +12,17 @@ import torch
# -------------------------------------------------------- # --------------------------------------------------------
# ******************************************************** # ********************************************************
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
}) })
@triton.jit @triton.jit
def _sdd_kernel( def _sdd_kernel(
A, B, C, A, B, C,
stride_za, stride_ha, stride_ma, stride_ak, stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb, stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc, stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut, K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr BLOCK: tl.constexpr, EVEN_K: tl.constexpr
): ):
@@ -30,25 +32,25 @@ def _sdd_kernel(
block_id = tl.program_id(1) + grid_offset block_id = tl.program_id(1) + grid_offset
lut += block_id * 3 lut += block_id * 3
# offsets # offsets
off_z = tl.program_id(2) # batch off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head off_h = tl.load(lut + 0) # head
# initialize pointers to A # initialize pointers to A
start_am = tl.load(lut + 1) start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K) offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + (off_z * stride_za \ a_ptrs = A + (off_z * stride_za
+ off_h * stride_ha \ + off_h * stride_ha
+ offs_am[:, None] * stride_ma \ + offs_am[:, None] * stride_ma
+ offs_ak[None, :] * stride_ak) + offs_ak[None, :] * stride_ak)
# initialize pointers to B # initialize pointers to B
start_bn = tl.load(lut + 2) start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K) offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + (off_z * stride_zb \ b_ptrs = B + (off_z * stride_zb
+ off_h * stride_hb \ + off_h * stride_hb
+ offs_bn[None, :] * stride_nb \ + offs_bn[None, :] * stride_nb
+ offs_bk[:, None] * stride_bk) + offs_bk[:, None] * stride_bk)
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
@@ -69,13 +71,14 @@ def _sdd_kernel(
## ---------------- ## ## ---------------- ##
offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + (off_z * stride_zc \ pc = C + (off_z * stride_zc
+ block_id * stride_hc \ + block_id * stride_hc
+ offs_cm[:, None] * stride_mc \ + offs_cm[:, None] * stride_mc
+ offs_cn[None, :] * stride_nc) + offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True) tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1: if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous() a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1: if b.stride(2) != 1 and b.stride(3) != 1:
@@ -103,7 +106,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3), c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut, Ka, 0, lut,
TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4, TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4, num_warps=4,
) )
return c return c
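For readers new to the block-sparse format these kernels share: `layout` is an (H, M//block, N//block) integer mask and the sparse operand stores one block x block tile per non-zero entry. A torch-only sketch of that packing (mirroring what `triton.testing.sparsify_tensor` does, without depending on it):

```
import torch

block, H, M, N = 16, 1, 64, 64
layout = torch.tril(torch.ones(H, M // block, N // block, dtype=torch.int64))
dense = torch.randn(1, H, M, N)

# gather the dense tile behind every non-zero layout entry
packed = torch.stack([
    dense[0, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    for h, i, j in layout.nonzero(as_tuple=False).tolist()
])
assert packed.shape == (int(layout.sum()), block, block)
```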
@@ -119,50 +122,52 @@ def sdd_lut(layout, block, device):
# This operation uses a look-up table that contains pre-computed pointer increments # This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel. # in order to minimize computations in the inner loop of the matmul kernel.
# ----------------------------- # -----------------------------
@triton.jit @triton.jit
def _dsd_kernel( def _dsd_kernel(
A, B, C, A, B, C,
stride_az, stride_ha, stride_am, stride_ak, stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn, stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn, stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut, DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
): ):
#------------# #------------#
#- Prologue -# #- Prologue -#
#------------# #------------#
pid_m = tl.program_id(0) pid_m = tl.program_id(0)
pid_n = tl.program_id(1) pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0) num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1) num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2) pidz = tl.program_id(2)
header = lut + pid_n * 4 header = lut + pid_n * 4
offset = tl.load(header + 0) offset = tl.load(header + 0)
K = tl.load(header + 1) K = tl.load(header + 1)
column = tl.load(header + 2) column = tl.load(header + 2)
off_h = tl.load(header + 3) off_h = tl.load(header + 3)
pinc = lut + offset pinc = lut + offset
# initialize pointers to A (sparse) # initialize pointers to A (sparse)
block_id = tl.load(pinc + 1) block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M) offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K) offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \ pa = A + pidz * stride_az \
+ block_id * stride_ha \ + block_id * stride_ha \
+ offs_am[:, None] * stride_am \ + offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak + offs_ak[None, :] * stride_ak
# initialize pointers to B (dense) # initialize pointers to B (dense)
offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N) offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc) start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K) offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \ pb = B + pidz * stride_zb \
+ off_h * stride_hb \ + off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \ + offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk + offs_bk[:, None] * stride_bk
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
@@ -177,7 +182,7 @@ def _dsd_kernel(
b = tl.load(pb, mask=offs_bn[None, :] < DS0) b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b) acc += tl.dot(a, b)
pa += inc_a pa += inc_a
pb += inc_b*stride_bk pb += inc_b * stride_bk
pinc += 2 pinc += 2
inc_a = tl.load(pinc + 1) inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8) inc_a = tl.multiple_of(inc_a, 8)
@@ -185,15 +190,16 @@ def _dsd_kernel(
inc_b = tl.multiple_of(inc_b, 8) inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty) c = acc.to(C.dtype.element_ty)
# initialize pointers to C # initialize pointers to C
offs_cm = column*TILE_M + tl.arange(0, TILE_M) offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N) offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \ pc = C + off_h * stride_hc \
+ pidz * stride_zc \ + pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \ + offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn + offs_cn[None, :] * stride_cn
tl.store(pc, c, mask = offs_cn[None, :] < DS0) tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1: if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous() a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1: if b.stride(2) != 1 and b.stride(3) != 1:
@@ -231,7 +237,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
# exit() # exit()
return c return c
def dsd_lut(layout, block, step, trans, device): def dsd_lut(layout, block, step, trans, device):
sizes = torch.sum(layout, 2 if trans else 1) sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = sizes.nonzero(as_tuple=True) head_id, col_id = sizes.nonzero(as_tuple=True)
sizes = sizes.flatten() sizes = sizes.flatten()
@@ -313,11 +319,11 @@ def dsd_lut(layout, block, step, trans, device):
# ----------------------------- # -----------------------------
@triton.jit @triton.jit
def _dds_kernel( def _dds_kernel(
A, B, C, A, B, C,
stride_za, stride_ha, stride_ma, stride_ka, stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn, stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc, stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut, DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
): ):
@@ -348,7 +354,7 @@ def _dds_kernel(
+ offs_ak[None, :] * stride_ka + offs_ak[None, :] * stride_ka
# initialize pointers to B (sparse) # initialize pointers to B (sparse)
block_id = tl.load(pinc + 1) block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) block_id = tl.multiple_of(block_id, 8)
offs_bn = tl.arange(0, TILE_N) offs_bn = tl.arange(0, TILE_N)
offs_bk = tl.arange(0, TILE_K) offs_bk = tl.arange(0, TILE_K)
ptrs_b = B + pid_z * stride_zb \ ptrs_b = B + pid_z * stride_zb \
@@ -429,7 +435,7 @@ class _matmul(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out c_lut, c_width, da_lut, da_width, db_lut, db_width, out
): ):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
@@ -499,10 +505,10 @@ class matmul:
def __call__(self, a, b, out = None): def __call__(self, a, b, out = None):
c = _matmul.apply( c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width, self.c_lut, self.c_width,
self.da_lut, self.da_width, self.da_lut, self.da_width,
self.db_lut, self.db_width, self.db_lut, self.db_width,
out out
) )
return c return c

View File

@@ -1,7 +1,8 @@
import triton.language as tl
import triton
import torch import torch
import triton
import triton.language as tl
def num_warps(n): def num_warps(n):
if n < 512: if n < 512:
@@ -33,10 +34,10 @@ def _forward(
check = rbn < size check = rbn < size
rbmn = tl.where(check, rbn, size - 1) rbmn = tl.where(check, rbn, size - 1)
# block id and column id # block id and column id
blockid = tl.load(LUT + offset + rbmn * 4 + 0) blockid = tl.load(LUT + offset + rbmn * 4 + 0)
columnid = tl.load(LUT + offset + rbmn * 4 + 1) columnid = tl.load(LUT + offset + rbmn * 4 + 1)
rowid = tl.load(LUT + offset + rbmn * 4 + 2) rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = tl.load(LUT + offset + rbmn * 4 + 3) headid = tl.load(LUT + offset + rbmn * 4 + 3)
# pointers to X # pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = tl.load(px, mask=check, other=-float('inf')) x = tl.load(px, mask=check, other=-float('inf'))
@@ -64,7 +65,7 @@ def _forward(
attn_m = tl.where(attn_m == 0, -float('inf'), 0.) attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m x = x + attn_m
# apply causal mask # apply causal mask
is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
# computation # computation
x = tl.softmax(x) x = tl.softmax(x)
@@ -127,9 +128,9 @@ class _softmax(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, x, scale, rpe, ctx, x, scale, rpe,
key_padding_mask, attn_mask, key_padding_mask, attn_mask,
kp_mask_mode, attn_mask_mode, kp_mask_mode, attn_mask_mode,
is_causal, is_causal,
spdims, block, lut, maxlut spdims, block, lut, maxlut
): ):
@@ -161,15 +162,15 @@ class _softmax(torch.autograd.Function):
# run kernel # run kernel
M = x.shape[0] M = x.shape[0]
grid = [spdims[0] * spdims[1] * block, M] grid = [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\ _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK = block, BLOCK=block,
APPLY_SCALE = apply_scale, APPLY_SCALE=apply_scale,
APPLY_RPE = apply_rpe, APPLY_RPE=apply_rpe,
APPLY_KP_MASK = apply_kp_mask, APPLY_KP_MASK=apply_kp_mask,
APPLY_ATTN_MASK = apply_attn_mask, APPLY_ATTN_MASK=apply_attn_mask,
KP_MASK_MUL = (kp_mask_mode == 'mul'), KP_MASK_MUL=(kp_mask_mode == 'mul'),
ATTN_MASK_MUL = (attn_mask_mode == 'mul')) ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
# save to context # save to context
ctx.mark_dirty(x) ctx.mark_dirty(x)
ctx.save_for_backward(x, lut) ctx.save_for_backward(x, lut)
@@ -211,10 +212,10 @@ class softmax:
self.lut_cache = dict() self.lut_cache = dict()
def __call__( def __call__(
self, x, scale=1., rpe=None, self, x, scale=1., rpe=None,
key_padding_mask=None, attn_mask=None, key_padding_mask=None, attn_mask=None,
key_padding_mask_mode='add', attn_mask_mode='add', key_padding_mask_mode='add', attn_mask_mode='add',
is_causal = False is_causal=False
): ):
if rpe is not None and rpe.dtype != x.dtype: if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype) raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -224,11 +225,11 @@ class softmax:
raise ValueError('Key padding mask must be %s' % x.dtype) raise ValueError('Key padding mask must be %s' % x.dtype)
lut, maxlut = self.make_lut(x.device) lut, maxlut = self.make_lut(x.device)
x = _softmax.apply( x = _softmax.apply(
x, scale, rpe, x, scale, rpe,
key_padding_mask, attn_mask, key_padding_mask, attn_mask,
key_padding_mask_mode, attn_mask_mode, key_padding_mask_mode, attn_mask_mode,
is_causal, is_causal,
self.spdims, self.block, self.spdims, self.block,
lut, maxlut lut, maxlut
) )
return x return x
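A rough usage sketch for the wrapper above (assumes a CUDA device; the constructor is assumed to take `(layout, block)`, consistent with `make_lut(x.device)` being resolved per call). The input holds one block x block tile per non-zero layout entry and is normalized in place:

```
import torch
from triton.ops.blocksparse import softmax

block, H, M = 16, 2, 128
layout = torch.tril(torch.ones(H, M // block, M // block, dtype=torch.int64))
sparse_softmax = softmax(layout, block)

batch = 4
x = torch.randn(batch, int(layout.sum()), block, block, device='cuda', dtype=torch.float16)
y = sparse_softmax(x, scale=1.0, is_causal=True)  # mark_dirty above: x is modified in place
```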

View File

@@ -1,7 +1,9 @@
import os import os
import torch
import triton import triton
import triton.language as tl import triton.language as tl
import torch
def next_power_of_2(n): def next_power_of_2(n):
@@ -104,4 +106,4 @@ class _cross_entropy(torch.autograd.Function):
return neg_logprobs, None return neg_logprobs, None
cross_entropy = _cross_entropy.apply cross_entropy = _cross_entropy.apply

View File

@@ -1,11 +1,14 @@
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
from .matmul_perf_model import * from .matmul_perf_model import *
def init_to_zero(name): def init_to_zero(name):
return lambda nargs: nargs[name].zero_() return lambda nargs: nargs[name].zero_()
def get_configs_io_bound(): def get_configs_io_bound():
configs = [] configs = []
for num_stages in [2, 3, 4, 5, 6]: for num_stages in [2, 3, 4, 5, 6]:
@@ -14,14 +17,15 @@ def get_configs_io_bound():
for block_n in [32, 64, 128, 256]: for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4 num_warps = 2 if block_n <= 64 else 4
configs.append( configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps)) num_stages=num_stages, num_warps=num_warps))
# split_k # split_k
for split_k in [2, 4, 8, 16]: for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs return configs
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
}) })
@@ -30,26 +34,26 @@ def get_configs_io_bound():
# basic configs for compute-bound matmuls # basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(), ] + get_configs_io_bound(),
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
prune_configs_by={ prune_configs_by={
'prune_num_stages_by' : prune_num_stages, 'prune_num_stages_by': prune_num_stages,
'perf_model': estimate_matmul_time, 'perf_model': estimate_matmul_time,
'top_k': 10 'top_k': 10
}, },
) )
@triton.jit @triton.jit
def _kernel(A, B, C, M, N, K, def _kernel(A, B, C, M, N, K,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
# matrix multiplication # matrix multiplication
@@ -68,12 +72,12 @@ def _kernel(A, B, C, M, N, K,
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K) rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers # pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K*SPLIT_K): for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K: if EVEN_K:
a = tl.load(A) a = tl.load(A)
b = tl.load(B) b = tl.load(B)
@@ -117,10 +121,10 @@ class _matmul(torch.autograd.Function):
c = torch.empty((M, N), device=device, dtype=a.dtype) c = torch.empty((M, N), device=device, dtype=a.dtype)
# launch kernel # launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K, _kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
b.stride(0), b.stride(1), b.stride(0), b.stride(1),
c.stride(0), c.stride(1), c.stride(0), c.stride(1),
GROUP_M=8) GROUP_M=8)
return c return c
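Quick usage sketch of the public wrapper around the autotuned kernel above (assumes a CUDA device; fp16 inputs exercise the tensor-core configs listed earlier). The first call for a given (M, N, K) triggers autotuning:

```
import torch
import triton
from triton.ops import matmul

a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
c = matmul(a, b)
triton.testing.assert_almost_equal(c, torch.matmul(a, b))
```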

View File

@@ -1,116 +1,121 @@
import heapq
import torch import torch
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps): def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS ''' ''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4) total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device) tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops return tflops
def estimate_matmul_time( def estimate_matmul_time(
# backend, device, # backend, device,
num_warps, num_stages, num_warps, num_stages,
M, N, K, M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs debug=False, **kwargs
): ):
''' return estimated running time in ms ''' return estimated running time in ms
= max(compute, loading) + store ''' = max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
num_cta_m = triton.cdiv(M, BLOCK_M) num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N) num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size # If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N) M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute # time to compute
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
compute_ms = total_ops / tput compute_ms = total_ops / tput
# time to load data # time to load data
num_sm = _triton.runtime.num_sm(backend, device) num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas/num_sm) active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5% active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache # assume 80% of (following) loads are in L2 cache
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2) load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2)
load_a_l2 = M*K*2*0.8*(num_cta_n-1) load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1)) load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N*K*2*0.8*(num_cta_m-1) load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
# total # total
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024) total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
# loading time in ms # loading time in ms
load_ms = total_dram/dram_bw + total_l2/l2_bw load_ms = total_dram / dram_bw + total_l2 / l2_bw
# estimate storing time # estimate storing time
store_bw = dram_bw * 0.6 # :o store_bw = dram_bw * 0.6 # :o
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1: if SPLIT_K == 1:
store_ms = store_c_dram /store_bw store_ms = store_c_dram / store_bw
else: else:
reduce_bw = store_bw reduce_bw = store_bw
store_ms = store_c_dram/reduce_bw store_ms = store_c_dram / reduce_bw
# c.zero_() # c.zero_()
zero_ms = M*N*2/(1024*1024)/store_bw zero_ms = M * N * 2 / (1024 * 1024) / store_bw
store_ms += zero_ms store_ms += zero_ms
    total_time_ms = max(compute_ms, load_ms) + store_ms
    if debug:
        print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
              f'loading time: {load_ms}ms, store time: {store_ms}ms, '
              f'Activate CTAs: {active_cta_ratio*100}%')
    return total_time_ms
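A back-of-the-envelope evaluation of the model above with made-up hardware numbers (108 SMs, 250 TFLOPS fp16 peak, 1.5 TB/s DRAM) and the L2-reuse corrections dropped, just to show which term dominates for a large square matmul:

```
M = N = K = 4096
BLOCK_M, BLOCK_N, SPLIT_K, num_warps = 128, 128, 1, 4

# compute term
num_ctas = (M // BLOCK_M) * (N // BLOCK_N) * SPLIT_K        # 1024 CTAs
num_subcores = 108 * 4
total_warps = num_ctas * min(num_warps, 4)
tput = min(num_subcores, total_warps) / num_subcores * 250  # TFLOPS, fully occupied here
compute_ms = (2 * M * N * K / 1024**3) / tput               # ~0.51 ms

# loading term (fp16 operands, ignoring the 0.8/0.2 L2-reuse split)
load_ms = (M * K + N * K) * 2 / (1500 * 1e9) * 1e3          # ~0.045 ms

print(round(compute_ms, 3), round(load_ms, 3), round(max(compute_ms, load_ms), 3))
```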
def prune_num_stages(configs): def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device) cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {} configs_map = {}
for config in configs: for config in configs:
kw = config.kwargs kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
        key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
        if key in configs_map:
            configs_map[key].append((config, num_stages))
        else:
            configs_map[key] = [(config, num_stages)]

    pruned_configs = []
    for k, v in configs_map.items():
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
        if cc >= 80:
            # compute cycles (only works for ampere GPUs)
            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
            mma_cycles = mmas / min(4, num_warps) * 8

            ldgsts_latency = 300  # Does this matter?
            optimal_num_stages = ldgsts_latency / mma_cycles

            # nearest stages, prefer large #stages
            nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
                                      if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)

            for n in nearest:
                pruned_configs.append(n[0])
        else:  # Volta & Turing only supports num_stages <= 2
            random_config = v[0][0]
            random_config.num_stages = 2
            pruned_configs.append(random_config)
    return pruned_configs

View File

@@ -1,10 +1,11 @@
import torch
import os import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
import subprocess import subprocess
import sys import sys
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
try: try:
import triton._C.libtriton.cutlass as _cutlass import triton._C.libtriton.cutlass as _cutlass
@@ -13,6 +14,7 @@ except ImportError:
_cutlass = None _cutlass = None
has_cutlass = False has_cutlass = False
def catch_oor(kernel, pytest_handle=None): def catch_oor(kernel, pytest_handle=None):
try: try:
res = kernel() res = kernel()
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device) c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
# run function # run function
dtype = str(a.dtype).split('.')[-1] dtype = str(a.dtype).split('.')[-1]
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, Ka,\ M, N, Ka,
a.stride(0), a.stride(1),\ a.stride(0), a.stride(1),
b.stride(0), b.stride(1),\ b.stride(0), b.stride(1),
c.stride(0), c.stride(1),\ c.stride(0), c.stride(1),
dtype, dtype, dtype, dtype, dtype, dtype,
a.device.index, torch.cuda.current_stream(a.device).cuda_stream) a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret return ret
def assert_almost_equal(x, y, decimal=2, err_msg=''): def assert_almost_equal(x, y, decimal=2, err_msg=''):
import numpy.testing as npt import numpy.testing as npt
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
@@ -93,6 +96,7 @@ def nvsmi(attrs):
ret = [int(x) for x in ret] ret = [int(x) for x in ret]
return ret return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False): def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
""" """
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -122,13 +126,13 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
torch.cuda.synchronize() torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5 estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat # compute number of warmup and repeat
n_warmup = max(1, int(warmup/estimate_ms)) n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep/estimate_ms)) n_repeat = max(1, int(rep / estimate_ms))
# We maintain a buffer of 256 MB that we clear # We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2 # before each kernel call to make sure that the L2
# doesn't contain any input data before the run # doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up # Warm-up
for _ in range(n_warmup): for _ in range(n_warmup):
@@ -161,6 +165,7 @@ class Benchmark:
""" """
This class is used by the :code:`perf_report` function to generate line plots with a concise API. This class is used by the :code:`perf_report` function to generate line plots with a concise API.
""" """
def __init__( def __init__(
self, self,
x_names, x_names,
@@ -224,9 +229,10 @@ class Mark:
self.benchmarks = benchmarks self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data): def _run(self, bench, save_path, show_plots, print_data):
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import os
y_mean = bench.line_names y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names] y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names] y_max = [f'{x}-max' for x in bench.line_names]
@@ -259,7 +265,7 @@ class Mark:
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names) xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
ax.set_xlabel(xlabel) ax.set_xlabel(xlabel)
ax.set_ylabel(bench.ylabel) ax.set_ylabel(bench.ylabel)
#ax.set_title(bench.plot_name) # ax.set_title(bench.plot_name)
ax.set_xscale("log" if bench.x_log else "linear") ax.set_xscale("log" if bench.x_log else "linear")
ax.set_yscale("log" if bench.y_log else "linear") ax.set_yscale("log" if bench.y_log else "linear")
if show_plots: if show_plots:
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks) wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper return wrapper
def get_dram_gbps(backend=None, device=None): def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s ''' ''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA # assert backend == CUDA
@@ -306,17 +313,18 @@ def get_dram_gbps(backend=None, device=None):
device = torch.cuda.current_device() device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device) bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps return bw_gbps
def get_max_tensorcore_tflops(backend, device): def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
# assume fp32 += fp16*fp16 # assume fp32 += fp16*fp16
cc = _triton.runtime.cc(backend, device) cc = _triton.runtime.cc(backend, device)
if cc < 80: if cc < 80:
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else: else:
ops_per_sub_core = 512 ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
return tflops return tflops
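Plugging assumed A100-like numbers into the two formulas above (1215 MHz memory clock, 5120-bit bus, 108 SMs, 1410 MHz SM clock, cc >= 80); these inputs are illustrative, not queried from a device:

```
mem_clock_khz, bus_width = 1_215_000, 5120
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8
print(bw_gbps)                # 1483 GB/s

num_subcores = 108 * 4        # 4 sub-cores per SM on recent GPUs
clock_rate_khz = 1_410_000
ops_per_sub_core = 512        # cc >= 80
tflops = num_subcores * clock_rate_khz * ops_per_sub_core / (1024 * 1024 * 1024)
print(round(tflops))          # ~290 TFLOPS
```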

View File

@@ -21,8 +21,8 @@
# SOFTWARE. # SOFTWARE.
import argparse import argparse
import subprocess
import re import re
import subprocess
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')

View File

@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel # Compute Kernel
# -------------------------- # --------------------------
from triton.language.core import constexpr
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -38,7 +38,7 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE) offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses # Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case # Load x and y from DRAM, masking out any extra elements in case
# the input is not a multiple of the block size # the input is not a multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask) x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask)

View File

@@ -16,6 +16,8 @@ You will learn about:
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation: # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import triton.language as tl
import triton
import torch import torch
@@ -59,13 +61,10 @@ def naive_softmax(x):
# power-of-two number of elements, so we need to internally "pad" each row and guard the # power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes: # memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit @triton.jit
def softmax_kernel( def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr BLOCK_SIZE: tl.constexpr
): ):
# The rows of the softmax are independent, so we parallelize across those # The rows of the softmax are independent, so we parallelize across those
@@ -136,7 +135,7 @@ y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1) y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch)) print(torch.allclose(y_triton, y_torch))
#%% # %%
# As expected, the results are identical. # As expected, the results are identical.
# %% # %%
@@ -187,5 +186,5 @@ benchmark.run(show_plots=True, print_data=True)
# In the above plot, we can see that: # In the above plot, we can see that:
# #
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. # - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. # - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape. # Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.

View File

@@ -112,13 +112,13 @@ You will specifically learn about:
# # number of programs ids along the N axis # # number of programs ids along the N axis
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# # number of programs in group # # number of programs in group
# num_pid_in_group = GROUP_SIZE_M * num_pid_n # num_pid_in_group = GROUP_SIZE_M * num_pid_n
# # id of the group this program is in # # id of the group this program is in
# group_id = pid // num_pid_in_group # group_id = pid // num_pid_in_group
# # row-id of the first program in the group # # row-id of the first program in the group
# first_pid_m = group_id * GROUP_SIZE_M # first_pid_m = group_id * GROUP_SIZE_M
# # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller # # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# # *within groups*, programs are ordered in a column-major order # # *within groups*, programs are ordered in a column-major order
# # row-id of the program in the *launch grid* # # row-id of the program in the *launch grid*
# pid_m = first_pid_m + (pid % group_size_m) # pid_m = first_pid_m + (pid % group_size_m)
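The grouped-ordering arithmetic in the comments above is easy to sanity-check on the host; the following is a plain-Python rendering of exactly those formulas (the 4x4 tile grid and GROUP_SIZE_M=2 are illustrative values).

```
def grouped_order(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    cdiv = lambda a, b: -(-a // b)
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    order = []
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order

# 4x4 grid of tiles, groups of 2 rows: tiles are visited column-major inside
# each group, which is what improves L2 reuse of the loaded A/B blocks.
print(grouped_order(M=4, N=4, BLOCK_SIZE_M=1, BLOCK_SIZE_N=1, GROUP_SIZE_M=2))
```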
@@ -141,6 +141,7 @@ You will specifically learn about:
# #
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -152,18 +153,19 @@ import triton.language as tl
# - An autotuning *key* whose change in values will trigger evaluation of all the # - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs # provided configs
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
], ],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
) )
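Conceptually, the decorator above times every candidate config the first time a new value of `key` is seen and reuses the winner afterwards. A hedged, framework-free sketch of that idea (the `key_fn`, the `config` parameter and the toy `run` function below are illustrative, not Triton's API):

```
import time

def autotune(configs, key_fn):
    def decorator(run):
        cache = {}
        def tuned(*args, **kwargs):
            key = key_fn(*args, **kwargs)
            if key not in cache:
                def bench(cfg):
                    start = time.perf_counter()
                    run(*args, config=cfg, **kwargs)
                    return time.perf_counter() - start
                cache[key] = min(configs, key=bench)   # keep the fastest config
            return run(*args, config=cache[key], **kwargs)
        return tuned
    return decorator

@autotune(configs=[{'BLOCK': 64}, {'BLOCK': 128}],
          key_fn=lambda M, N, K: (M, N, K))            # mirrors key=['M', 'N', 'K']
def run(M, N, K, config):
    pass                                               # launch the real kernel here
```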
@@ -185,7 +187,7 @@ def matmul_kernel(
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr, ACTIVATION: tl.constexpr,
): ):
"""Kernel for computing the matmul C = A x B. """Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N) A has shape (M, K), B has shape (K, N) and C has shape (M, N)
""" """
@@ -196,16 +198,16 @@ def matmul_kernel(
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m) pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction # We will advance this pointer as we move in the K direction
# and accumulate # and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
@@ -213,8 +215,8 @@ def matmul_kernel(
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Iterate to compute a block of the C matrix # Iterate to compute a block of the C matrix
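A hedged host-side check of the pointer arithmetic above: for a row-major `(M, K)` tensor, element `a[i, k]` lives `i * stride_am + k * stride_ak` elements past the base pointer (PyTorch strides are counted in elements, not bytes; the 6x4 tensor below is illustrative).

```
import torch

a = torch.arange(6 * 4, dtype=torch.float32).reshape(6, 4)
stride_am, stride_ak = a.stride()                    # (4, 1) for this contiguous tensor
flat = a.flatten()
i, k = 3, 2
assert flat[i * stride_am + k * stride_ak] == a[i, k]
```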
@@ -223,8 +225,8 @@ def matmul_kernel(
# `accumulator` will be converted back to fp16 after the loop # `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K): for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here. # Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K, # This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an # this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results. # error or (worse!) incorrect results.
a = tl.load(a_ptrs) a = tl.load(a_ptrs)
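If K were not a multiple of BLOCK_SIZE_K, the loads in this loop could be guarded instead of left unmasked. A hedged sketch of one way to do that, assuming the usual `accumulator += tl.dot(a, b)` body and the `offs_k`, `stride_ak`, `stride_bk` names defined earlier in the kernel:

```
# inside the K loop: mask out the columns of A / rows of B that fall past K
# and substitute zeros, which contribute nothing to the dot product
for k in range(0, K, BLOCK_SIZE_K):
    k_mask = (k + offs_k) < K
    a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.)
    b = tl.load(b_ptrs, mask=k_mask[:, None], other=0.)
    accumulator += tl.dot(a, b)
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
```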
@@ -236,7 +238,7 @@ def matmul_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here # you can fuse arbitrary activation functions here
# while the accumulator is still in FP32! # while the accumulator is still in FP32!
if meta['ACTIVATION']: if meta['ACTIVATION']:
accumulator = meta['ACTIVATION'](accumulator) accumulator = meta['ACTIVATION'](accumulator)
c = accumulator.to(tl.float16) c = accumulator.to(tl.float16)
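One activation that could be fused this way is a leaky ReLU written as a Triton function and passed through the meta-parameters, so it runs while the accumulator is still in FP32. A hedged sketch (the `matmul` wrapper and its `activation` argument are hypothetical here):

```
@triton.jit
def leaky_relu(x):
    # elementwise leaky ReLU, computed on the FP32 accumulator
    return tl.where(x >= 0, x, 0.01 * x)

# e.g. matmul(a, b, activation=leaky_relu), assuming a host-side wrapper that
# forwards `activation` as the ACTIVATION meta-parameter of matmul_kernel.
```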

View File

@@ -13,7 +13,7 @@ whose state is generally composed of a bit mask tensor of the same shape as the
# %% # %%
# Baseline # Baseline
# ------------- # -------------
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance # The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
# of deep neural networks in low-data regimes (i.e., regularization). # of deep neural networks in low-data regimes (i.e., regularization).
# #
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the # It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
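For comparison, the standard "inverted dropout" formulation is only a few lines of PyTorch: zero each element whose mask entry is 0 and rescale the survivors by `1 / (1 - p)` so the expected value of the output matches the input (a hedged sketch, not part of the tutorial code).

```
import torch

def dropout_reference(x, x_keep, p):
    # x_keep is a boolean mask; survivors are rescaled to preserve the mean
    return torch.where(x_keep, x / (1 - p), torch.zeros_like(x))

x = torch.randn(10)
x_keep = torch.rand(10) > 0.5
print(dropout_reference(x, x_keep, p=0.5))
```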
@@ -30,16 +30,18 @@ whose state is generally composed of a bit mask tensor of the same shape as the
import tabulate import tabulate
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@triton.jit @triton.jit
def _dropout( def _dropout(
x_ptr, # pointer to the input x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero p, # probability that an element of `x` is changed to zero
**meta, **meta,
): ):
BLOCK_SIZE = meta['BLOCK_SIZE'] BLOCK_SIZE = meta['BLOCK_SIZE']
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output return output
# Input tensor # Input tensor
x = torch.randn(size=(10,)).cuda() x = torch.randn(size=(10,)).cuda()
# Dropout mask # Dropout mask
@@ -88,7 +91,7 @@ print(tabulate.tabulate([
# of persisting randomness across multiple invocations of the kernel. # of persisting randomness across multiple invocations of the kernel.
# #
# Pseudorandom number generation in Triton is simple! In this tutorial we will use the # Pseudorandom number generation in Triton is simple! In this tutorial we will use the
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` # :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides # values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
# other :ref:`random number generation strategies <Random Number Generation>`. # other :ref:`random number generation strategies <Random Number Generation>`.
# #
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
# #
# Let's put it all together. # Let's put it all together.
@triton.jit @triton.jit
def _seeded_dropout( def _seeded_dropout(
x_ptr, x_ptr,
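Given the description above, the core of the seeded variant only needs `tl.rand(seed, offsets)` plus a comparison against `p`. A hedged sketch of those lines, assuming `x`, `offsets`, `mask`, `output_ptr`, `seed` and `p` are in scope inside the kernel, as in the unseeded version:

```
# draw one uniform float32 in [0, 1) per offset, deterministically from `seed`
random = tl.rand(seed, offsets)
x_keep = random > p                          # keep each element with probability 1 - p
output = tl.where(x_keep, x / (1 - p), 0.0)
tl.store(output_ptr + offsets, output, mask=mask)
```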

View File

@@ -4,15 +4,17 @@ Layer Normalization
""" """
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
# Forward Pass # Forward Pass
@triton.jit @triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] BLOCK_SIZE = META['BLOCK_SIZE']
# position of elements processed by this program # position of elements processed by this program
row = tl.program_id(0) row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE) cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N mask = cols < N
# offset data pointers to start at the row of interest # offset data pointers to start at the row of interest
@@ -24,9 +26,9 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
mean = tl.sum(x, axis=0) / N mean = tl.sum(x, axis=0) / N
# compute std # compute std
xmean = tl.where(mask, x - mean, 0.) xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps) rstd = 1 / tl.sqrt(var + eps)
xhat = xmean*rstd xhat = xmean * rstd
# write-back mean/rstd # write-back mean/rstd
tl.store(M + row, mean) tl.store(M + row, mean)
tl.store(V + row, rstd) tl.store(V + row, rstd)
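For reference, the per-row math in `_layer_norm_fwd_fused` maps onto a few lines of PyTorch (a hedged sketch; it assumes the omitted tail of the kernel applies weight and bias as `y = xhat * w + b`):

```
import torch

def layer_norm_reference(X, W, B, eps=1e-5):
    # per-row mean / variance, biased (divide by N), matching the kernel
    mean = X.mean(dim=-1, keepdim=True)
    var = X.var(dim=-1, unbiased=False, keepdim=True)
    rstd = 1.0 / torch.sqrt(var + eps)
    xhat = (X - mean) * rstd
    return xhat * W + B, mean.squeeze(-1), rstd.squeeze(-1)

X = torch.randn(8, 512)
W, B = torch.rand(512), torch.rand(512)
y_ref, _, _ = layer_norm_reference(X, W, B)
assert torch.allclose(y_ref, torch.nn.functional.layer_norm(X, (512,), W, B), atol=1e-5)
```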
@@ -41,16 +43,16 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
# Backward pass (DX + partial DW + partial DB) # Backward pass (DX + partial DW + partial DB)
@triton.jit @triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
stride, N, eps, stride, N, eps,
**META): **META):
GROUP_SIZE_M = META['GROUP_SIZE_M'] GROUP_SIZE_M = META['GROUP_SIZE_M']
BLOCK_SIZE_N = META['BLOCK_SIZE_N'] BLOCK_SIZE_N = META['BLOCK_SIZE_N']
# position of elements processed by this program # position of elements processed by this program
row = tl.program_id(0) row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N) cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N mask = cols < N
# offset data pointers to start at the row of interest # offset data pointers to start at the row of interest
X += row * stride X += row * stride
DY += row * stride DY += row * stride
DX += row * stride DX += row * stride
# offset locks and weight/bias gradient pointer # offset locks and weight/bias gradient pointer
@@ -59,28 +61,28 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
# these buffers stay in the L2, which allows this kernel # these buffers stay in the L2, which allows this kernel
# to be fast # to be fast
lock_id = row % GROUP_SIZE_M lock_id = row % GROUP_SIZE_M
Lock += lock_id Lock += lock_id
Count = Lock + GROUP_SIZE_M Count = Lock + GROUP_SIZE_M
DW = DW + lock_id*N + cols DW = DW + lock_id * N + cols
DB = DB + lock_id*N + cols DB = DB + lock_id * N + cols
# load data to SRAM # load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32) w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row) mean = tl.load(M + row)
rstd = tl.load(V + row) rstd = tl.load(V + row)
# compute dx # compute dx
xhat = (x - mean)*rstd xhat = (x - mean) * rstd
wdy = w * dy wdy = w * dy
xhat = tl.where(mask, xhat, 0.) xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy , 0.) wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat*mean1 + mean2))*rstd dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx # write-back dx
tl.store(DX + cols, dx, mask=mask) tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db # accumulate partial sums for dw/db
partial_dw = (dy*xhat).to(w.dtype) partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype) partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1: while tl.atomic_cas(Lock, 0, 1) == 1:
pass pass
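The `dx` expression above is the usual layer-norm input gradient. Written out in PyTorch for a 2D `(rows, N)` layout with per-row `mean`/`rstd` of shape `(rows, 1)`, it reads as follows (a hedged sketch, handy for checking the kernel against autograd):

```
import torch

def layer_norm_dx_reference(dy, x, w, mean, rstd):
    xhat = (x - mean) * rstd
    wdy = w * dy
    mean1 = (xhat * wdy).mean(dim=-1, keepdim=True)   # sum(xhat * wdy) / N
    mean2 = wdy.mean(dim=-1, keepdim=True)            # sum(wdy) / N
    return (wdy - (xhat * mean1 + mean2)) * rstd
```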
@@ -97,24 +99,27 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
tl.atomic_xchg(Lock, 0) tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB) # Backward pass (total DW + total DB)
@triton.jit @triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
pid = tl.program_id(0) pid = tl.program_id(0)
BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M): for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
mask = (rows[:, None] < M) & (cols[None, :] < N) mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None]*N + cols[None, :] offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.) dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.) db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0) sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0) sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols<N) tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols<N) tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function): class LayerNorm(torch.autograd.Function):
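What `_layer_norm_bwd_dwdb` computes is just a row reduction of the partial buffers written by the dx kernel, tiled so each program owns a slice of columns. A hedged PyTorch equivalent (the sizes below are illustrative):

```
import torch

GROUP_SIZE_M, N = 128, 4096
_dw = torch.randn(GROUP_SIZE_M, N, dtype=torch.float16)
_db = torch.randn(GROUP_SIZE_M, N, dtype=torch.float16)
dw = _dw.float().sum(dim=0)   # shape (N,); accumulate in float32 like the kernel
db = _db.float().sum(dim=0)
```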
@@ -129,19 +134,19 @@ class LayerNorm(torch.autograd.Function):
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda') rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel # Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size() MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE: if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps # heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8) num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel # enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps, x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps ctx.num_warps = num_warps
ctx.eps = eps ctx.eps = eps
return y return y
@staticmethod @staticmethod
@@ -154,11 +159,11 @@ class LayerNorm(torch.autograd.Function):
if N <= 4096: GROUP_SIZE_M = 128 if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256 if N <= 1024: GROUP_SIZE_M = 256
# allocate output # allocate output
locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda') locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy) dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics # enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB # also compute partial sums for DW and DB
@@ -166,14 +171,14 @@ class LayerNorm(torch.autograd.Function):
M, N = x_arg.shape M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps, x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE, BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M, GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps) num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel # accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M = 32, BLOCK_SIZE_M=32,
BLOCK_SIZE_N = 128) BLOCK_SIZE_N=128)
return dx, None, dw, db, None return dx, None, dw, db, None
@@ -184,10 +189,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data # create data
x_shape = (M, N) x_shape = (M, N)
w_shape = (x_shape[-1], ) w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x) dy = .1 * torch.randn_like(x)
x.requires_grad_(True) x.requires_grad_(True)
# forward pass # forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps) y_tri = layer_norm(x, w_shape, weight, bias, eps)
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1) triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1) triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['N'], x_names=['N'],
@@ -218,14 +224,14 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
) )
) )
def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'): def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data # create data
x_shape = (M, N) x_shape = (M, N)
w_shape = (x_shape[-1], ) w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x) dy = .1 * torch.randn_like(x)
x.requires_grad_(True) x.requires_grad_(True)
# utility functions # utility functions
if provider == 'triton': if provider == 'triton':
@@ -238,14 +244,15 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
y_fwd = lambda: apex_layer_norm(x) y_fwd = lambda: apex_layer_norm(x)
# forward pass # forward pass
if mode == 'forward': if mode == 'forward':
gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6 gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500) ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass # backward pass
if mode == 'backward': if mode == 'backward':
gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6 gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd() y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500) grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms) return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True) bench_layer_norm.run(save_path='.', print_data=True)
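A hedged sanity check of the GB/s conversion used above: `bytes / ms * 1e-6` equals GB/s, since 1 MB per millisecond is 1 GB per second. For instance, with M = 4096 and N = 8192 in float16 (an example size, not necessarily one of the benchmarked points), the forward-pass formula counts 2 * 4096 * 8192 * 2 bytes, about 0.13 GB.

```
M, N, elem_bytes = 4096, 8192, 2             # float16 element size in bytes
print(2 * M * N * elem_bytes * 1e-9, "GB")   # ~0.134
```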