[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` in place, then clean up a few warts by hand. This PR should be a no-op: it is all boring whitespace and import-order changes, and any config-file changes will land in a separate change to keep this one easy to review.
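For reference, the pycodestyle codes excluded above are E501 (line too long), E701 (multiple statements on one line after a colon), and E731 (assignment of a lambda expression); `-i` tells autopep8 to rewrite files in place. The `.isort.cfg` itself ships in a separate change, so the snippet below is only a minimal sketch of the kind of whitespace-only edit that dominates the diff: dropping the spaces around `=` in keyword arguments and defaults (E251). The `benchmark` function and its arguments are made up for illustration.

```python
# After autopep8: no spaces around '=' for keyword arguments and defaults (E251).
def benchmark(x_names=None, ylabel='TFLOPS'):   # was: x_names = None, ylabel = 'TFLOPS'
    return x_names, ylabel

tflops = lambda ms: 1.0 / ms   # left alone: E731 is in the --ignore list
print(benchmark(x_names=['M', 'N', 'K']), tflops(2.0))
```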
Author: Madeleine Thompson
Date: 2022-01-06 14:34:17 -08:00 (committed by GitHub)
Parent: 120cda015e
Commit: 8bf551ae7a
30 changed files with 742 additions and 623 deletions


@@ -1,4 +1,5 @@
import torch import torch
import triton import triton
# ------------------------------- # -------------------------------
@@ -8,18 +9,18 @@ import triton
nt = {False: 'n', True: 't'} nt = {False: 'n', True: 't'}
square_confs = [ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['M', 'N', 'K'], x_names=['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block', line_arg='block',
line_vals = [16, 32, 64, 128], line_vals=[16, 32, 64, 128],
line_names = ['Block16', 'Block32', 'Block64', 'Block128'], line_names=['Block16', 'Block32', 'Block64', 'Block128'],
ylabel = 'TFLOPS', ylabel='TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode, args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\ )
for AT in [False] for BT in [False] \ for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense'] for op_mode in ['dsd'] for layout_mode in ['dense']
] ]
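The hunk above mostly drops trailing backslashes: inside brackets Python continues lines implicitly, so the `\` after `)` and between the `for` clauses was redundant. A stripped-down sketch, with the `triton.testing.Benchmark` construction replaced by a plain tuple:

```python
# Inside (), [] or {} Python continues the line implicitly, so the trailing
# backslashes after ')' and between the 'for' clauses were redundant.
square_confs_sketch = [
    (AT, BT, op_mode, layout_mode)
    for AT in [False]
    for BT in [False]
    for op_mode in ['dsd']
    for layout_mode in ['dense']
]
print(square_confs_sketch)   # -> [(False, False, 'dsd', 'dense')]
```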
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1 Z, H = 1, 1
make_layout = { make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode] }[layout_mode]
# create layout # create layout
@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = { num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\ 'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\ 'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block 'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode]*1e-12 }[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms) return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
square_confs = [ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['M', 'N'], x_names=['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block', line_arg='block',
line_vals = [16, 32, 64], line_vals=[16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'], line_names=['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS', ylabel='GBPS',
plot_name = f'{layout_mode}-square', plot_name=f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\ )
for layout_mode in ['dense', 'tril'] for layout_mode in ['dense', 'tril']
] ]


@@ -1,17 +1,18 @@
import torch import torch
import triton import triton
confs = [ confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['N'], x_names=['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg = 'provider', line_arg='provider',
line_vals = ['triton', 'torch'], line_vals=['triton', 'torch'],
line_names = ['Triton', 'Torch'], line_names=['Triton', 'Torch'],
ylabel = 'GBPS', ylabel='GBPS',
plot_name = f'{mode}-2048', plot_name=f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\ )
for mode in ['forward', 'backward'] for mode in ['forward', 'backward']
] ]
@@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9) num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3 gbps = lambda ms: num_gb / ms * 1e3
# forward pass # forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \ op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider] 'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward': if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx)) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward': if mode == 'backward':
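The cross-entropy benchmark above keeps its small helper lambdas (`gbps = lambda ms: ...`) because E731 is in the `--ignore` list; only the surrounding whitespace is normalized. A minimal sketch with made-up sizes:

```python
# E731 ("do not assign a lambda expression") is ignored, so helper lambdas
# like this one survive the reformat untouched.
num_gb = 2 * 1024 * 4 * 1e-9          # hypothetical tensor: 1024 fp32 elements, read + write
gbps = lambda ms: num_gb / ms * 1e3   # same shape as the benchmark helpers
print(f"{gbps(0.5):.5f} GB/s")
```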


@@ -1,6 +1,6 @@
import triton
import torch import torch
import os
import triton
def rounded_linspace(low, high, steps, div): def rounded_linspace(low, high, steps, div):
@@ -29,16 +29,16 @@ square_confs = [
transformer_confs = [ transformer_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=[x], x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128), x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider", line_arg="provider",
line_vals=["cublas", "triton", "cutlass"], line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"], line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS", ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\ ) for NK in [12288]
for i, x in enumerate(["N", "K"])\ for i, x in enumerate(["N", "K"])
for M in [2048] for M in [2048]
] ]
@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t() if AT:
if BT: b = b.t() a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9 tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas": if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try: try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep) ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms) return tflops(ms), tflops(max_ms), tflops(min_ms)
except: except Exception:
return None return None
return None return None
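This hunk (and a similar one in the build script further down) tightens bare `except:` clauses to `except Exception:`, since a bare except also swallows `SystemExit` and `KeyboardInterrupt`. A small sketch, with a hypothetical `might_fail` standing in for the guarded call:

```python
def might_fail():
    # Stand-in for the cutlass_matmul / shutil.rmtree calls guarded in the diff.
    raise RuntimeError("no CUTLASS build available")

try:
    result = might_fail()
except Exception:   # was a bare "except:", which would also swallow SystemExit/KeyboardInterrupt
    result = None
print(result)
```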


@@ -1,7 +1,8 @@
import argparse import argparse
import sys
import os
import inspect import inspect
import os
import sys
import triton import triton
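This is pure isort output: standard-library imports are alphabetized and third-party packages follow in their own group. A hedged before/after sketch (the exact grouping and blank lines depend on the `.isort.cfg` that lands separately):

```python
# Old order (top of the file before this change):
#   import argparse
#   import sys
#   import os
#   import inspect
#   import triton
#
# New order: standard library alphabetized first, third-party packages after.
import argparse
import inspect
import os
import sys

import triton
```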


@@ -1,20 +1,19 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils import distutils
import glob
import tempfile
import shutil
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn import distutils.spawn
import urllib.request import os
import platform
import re
import shutil
import subprocess
import sys
import tarfile import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
def get_llvm(): def get_llvm():
# tries to find system LLVM # tries to find system LLVM
@@ -23,7 +22,7 @@ def get_llvm():
paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None] paths = [p for p in paths if p is not None]
if paths: if paths:
return '', '' return '', ''
# download if nothing is installed # download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04' name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp' dir = '/tmp'
@@ -32,7 +31,7 @@ def get_llvm():
if not os.path.exists(llvm_library_dir): if not os.path.exists(llvm_library_dir):
try: try:
shutil.rmtree(os.path.join(dir, name)) shutil.rmtree(os.path.join(dir, name))
except: except Exception:
pass pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name) url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...') print('downloading and extracting ' + url + '...')
@@ -96,7 +95,7 @@ class CMakeBuild(build_ext):
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir, "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir, "-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
#'-DPYTHON_EXECUTABLE=' + sys.executable, #'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
] ]
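Besides the import reshuffle and the `except Exception:` fix, the build-script hunk above gives a commented-out CMake flag a space after the `#`, which is what pycodestyle E265 asks of block comments. For illustration (the flag value shown is just a placeholder):

```python
cmake_args = [
    "-DTRITON_LLVM_BUILD_DIR=/tmp/llvm",     # illustrative value only
    #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',     # old: flagged by E265 (no space after '#')
    # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',    # new: block comment starts with '# '
]
print(cmake_args)
```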


@@ -1,14 +1,18 @@
from numpy import record import triton.language as tl
import torch
import triton
import subprocess import subprocess
import sys import sys
import pytest import pytest
import torch
from numpy import record
import triton
####################### #######################
# Utilities # Utilities
####################### #######################
def nvsmi(attrs): def nvsmi(attrs):
attrs = ','.join(attrs) attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -23,48 +27,51 @@ def nvsmi(attrs):
####################### #######################
matmul_data = { matmul_data = {
# square # square
(256 , 256 , 256 ) : {'v100': 0.027}, (256, 256, 256): {'v100': 0.027},
(512 , 512 , 512 ) : {'v100': 0.158}, (512, 512, 512): {'v100': 0.158},
(1024, 1024, 1024 ) : {'v100': 0.466}, (1024, 1024, 1024): {'v100': 0.466},
(2048, 2048, 2048 ) : {'v100': 0.680}, (2048, 2048, 2048): {'v100': 0.680},
(4096, 4096, 4096 ) : {'v100': 0.831}, (4096, 4096, 4096): {'v100': 0.831},
(8192, 8192, 8192 ) : {'v100': 0.849}, (8192, 8192, 8192): {'v100': 0.849},
# tall-skinny # tall-skinny
(16 , 1024, 1024 ) : {'v100': 0.0128}, (16, 1024, 1024): {'v100': 0.0128},
(16 , 4096, 4096 ) : {'v100': 0.0883}, (16, 4096, 4096): {'v100': 0.0883},
(16 , 8192, 8192 ) : {'v100': 0.101}, (16, 8192, 8192): {'v100': 0.101},
(64 , 1024, 1024 ) : {'v100': 0.073}, (64, 1024, 1024): {'v100': 0.073},
(64 , 4096, 4096 ) : {'v100': 0.270}, (64, 4096, 4096): {'v100': 0.270},
(64 , 8192, 8192 ) : {'v100': 0.360}, (64, 8192, 8192): {'v100': 0.360},
(1024, 64 , 1024 ) : {'v100': 0.0692}, (1024, 64, 1024): {'v100': 0.0692},
(4096, 64 , 4096 ) : {'v100': 0.264}, (4096, 64, 4096): {'v100': 0.264},
(8192, 64 , 8192 ) : {'v100': 0.323}, (8192, 64, 8192): {'v100': 0.323},
# # deep reductions # # deep reductions
# (64 , 64 , 16384) : {'v100': 0.}, # (64 , 64 , 16384) : {'v100': 0.},
# (64 , 64 , 65536) : {'v100': 0.}, # (64 , 64 , 65536) : {'v100': 0.},
# (256 , 256 , 8192 ) : {'v100': 0.}, # (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.}, # (256 , 256 , 32768) : {'v100': 0.},
} }
@pytest.mark.parametrize('M, N, K', matmul_data.keys()) @pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K): def test_matmul(M, N, K):
ref_gpu_util = matmul_data[(M, N, K)]['v100'] ref_gpu_util = matmul_data[(M, N, K)]['v100']
cur_sm_clock = nvsmi(['clocks.current.sm'])[0] cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = 1350 ref_sm_clock = 1350
max_gpu_perf = 1e-6*80*8*128*cur_sm_clock max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz' assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
a = torch.randn((M, K), dtype=torch.float16, device='cuda') a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((K, N), dtype=torch.float16, device='cuda') b = torch.randn((K, N), dtype=torch.float16, device='cuda')
fn = lambda: triton.ops.matmul(a, b) fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2.*M*N*K/ms * 1e-9 cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
####################### #######################
# Element-Wise # Element-Wise
####################### #######################
import triton.language as tl
@triton.jit @triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements, def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -80,21 +87,22 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = { elementwise_data = {
1024*16 : {'v100': 0.0219}, 1024 * 16: {'v100': 0.0219},
1024*64 : {'v100': 0.0791}, 1024 * 64: {'v100': 0.0791},
1024*256 : {'v100': 0.243}, 1024 * 256: {'v100': 0.243},
1024*1024 : {'v100': 0.534}, 1024 * 1024: {'v100': 0.534},
1024*4096 : {'v100': 0.796}, 1024 * 4096: {'v100': 0.796},
1024*16384: {'v100': 0.905}, 1024 * 16384: {'v100': 0.905},
1024*65536: {'v100': 0.939}, 1024 * 65536: {'v100': 0.939},
} }
@pytest.mark.parametrize('N', elementwise_data.keys()) @pytest.mark.parametrize('N', elementwise_data.keys())
def test_elementwise(N): def test_elementwise(N):
ref_gpu_util = elementwise_data[N]['v100'] ref_gpu_util = elementwise_data[N]['v100']
cur_mem_clock = nvsmi(['clocks.current.memory'])[0] cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = 877 ref_mem_clock = 877
max_gpu_perf = 512*2*ref_mem_clock*1e-3 max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda') z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z) x = torch.randn_like(z)
@@ -102,7 +110,6 @@ def test_elementwise(N):
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6 cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
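The regression-test hunks above are dominated by one change: spaces inserted around binary operators (pycodestyle E225/E226), e.g. `2.*M*N*K/ms*1e-9` becoming `2. * M * N * K / ms * 1e-9`. The value is identical, as this sketch with made-up sizes checks:

```python
M = N = K = 1024
ms = 0.5

old_style = 2.*M*N*K/ms*1e-9            # dense, pre-reformat spelling
new_style = 2. * M * N * K / ms * 1e-9  # spaced per pycodestyle E225/E226
assert old_style == new_style           # purely cosmetic change
print(f"{new_style:.3f} TFLOPS")
```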


@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'): def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128 SIZE = 128
@triton.jit @triton.jit
def kernel(X, SIZE: tl.constexpr): def kernel(X, SIZE: tl.constexpr):
pass pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128 SIZE = 128
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
def kernel(Z, X, SIZE: tl.constexpr): def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE) off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128 SIZE = 128
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
def kernel(Z, X, Y, SIZE: tl.constexpr): def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE) off = tl.arange(0, SIZE)
@@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# --------------- # ---------------
# test binary ops # test binary ops
# --------------- # ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [ @pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op) (dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%'] for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes for dtype_x in dtypes
for dtype_y in dtypes for dtype_y in dtypes
]) ])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'): def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y' expr = f' x {op} y'
@@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
@pytest.mark.parametrize("dtype_x, dtype_y", @pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
) )
def test_floordiv(dtype_x, dtype_y, device='cuda'): def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry # Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a # through to //, so we have to use a nonstandard expression to get a
@@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# test compare ops # test compare ops
# --------------- # ---------------
ops = ['==', '!=', '>', '<', '>=', '<='] ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
# real
[
(dtype_x, dtype_y, op, 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', op, mode_x, mode_y) \
for op in ops
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
('nan' , 'nan')]
])
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
# real
[
(dtype_x, dtype_y, op, 'real', 'real')
for op in ops
for dtype_x in dtypes
for dtype_y in dtypes
] +
# NaNs
[('float32', 'float32', op, mode_x, mode_y)
for op in ops
for mode_x, mode_y in [('nan', 'real'),
('real', 'nan'),
('nan', 'nan')]
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y' expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# 'exp', 'log', 'cos', 'sin' # 'exp', 'log', 'cos', 'sin'
# ]) # ])
@pytest.mark.parametrize("expr", [ @pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin' 'exp', 'log', 'cos', 'sin'
]) ])
@@ -368,8 +376,8 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [ @pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d) (f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16'] for d in ['int32', 'uint32', 'uint16']
]) ])
def test_index1d(expr, dtype_str, device='cuda'): def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':') rank_x = expr.count(':')
@@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
@triton.jit @triton.jit
def fn(a, b): def fn(a, b):
return a + b, \ return a + b, \
a - b, \ a - b, \
a * b a * b
def test_tuples(): def test_tuples():
@@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
# --------------- # ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False) (dtype_x, dtype_z, False)
for dtype_x in dtypes for dtype_x in dtypes
for dtype_z in dtypes for dtype_z in dtypes
] + [ ] + [
('float32', 'bfloat16', False), ('float32', 'bfloat16', False),
('bfloat16', 'float32', False), ('bfloat16', 'float32', False),
@@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
@triton.jit @triton.jit
def kernel(X, Z, BITCAST: tl.constexpr): def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X) x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast = BITCAST) z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z) tl.store(Z, z)
# triton result # triton result
@@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# --------------- # ---------------
# test reduce # test reduce
# --------------- # ---------------
@pytest.mark.parametrize("dtype_str, shape", @pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape) \ [(dtype, shape)
for dtype in dtypes\ for dtype in dtypes
for shape in [128, 512]]) for shape in [128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'): def test_reduce1d(dtype_str, shape, device='cuda'):
# triton kernel # triton kernel
@@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M) range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N) range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :]) x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = tl.sum(x, axis=AXIS) z = tl.sum(x, axis=AXIS)
tl.store(Z + range_m, z) tl.store(Z + range_m, z)
# input # input
@@ -608,11 +618,13 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# --------------- # ---------------
# test permute # test permute
# --------------- # ---------------
@pytest.mark.parametrize("dtype_str, shape, perm", @pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) \ [(dtype, shape, perm)
for dtype in ['float32']\ for dtype in ['float32']
for shape in [(128, 128)]\ for shape in [(128, 128)]
for perm in [(1, 0)]]) for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'): def test_permute(dtype_str, shape, perm, device='cuda'):
# triton kernel # triton kernel
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# test dot # test dot
# --------------- # ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'): def test_dot(epilogue, device='cuda'):
# triton kernel # triton kernel
@@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1),
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX = epilogue=='add-matrix', ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS = epilogue=='add-rows', ADD_ROWS=epilogue == 'add-rows',
ADD_COLS = epilogue=='add-cols') ADD_COLS=epilogue == 'add-cols')
# torch result # torch result
z_ref = np.matmul(x, y) z_ref = np.matmul(x, y)
if epilogue == 'add-matrix': if epilogue == 'add-matrix':
z_ref += z z_ref += z
if epilogue == 'add-rows': if epilogue == 'add-rows':
z_ref += z[:,0][:, None] z_ref += z[:, 0][:, None]
if epilogue == 'add-cols': if epilogue == 'add-cols':
z_ref += z[0,:][None, :] z_ref += z[0, :][None, :]
# compare # compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized # make sure ld/st are vectorized
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx assert 'st.global.v4' in ptx
def test_dot_without_load(): def test_dot_without_load():
@triton.jit @triton.jit
def kernel(out): def kernel(out):
@@ -713,28 +727,30 @@ def test_dot_without_load():
b = tl.zeros((32, 32), tl.float32) b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32) c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b) c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :] pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c) tl.store(pout, c)
out = torch.ones((32,32), dtype=torch.float32, device="cuda") out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out) kernel[(1,)](out)
# --------------- # ---------------
# test arange # test arange
# --------------- # ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16]) @pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'): def test_arange(start, device='cuda'):
BLOCK = 128 BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit @triton.jit
def _kernel(z, BLOCK: tl.constexpr, def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr): START: tl.constexpr, END: tl.constexpr):
off = tl.arange(0, BLOCK) off = tl.arange(0, BLOCK)
val = tl.arange(START, END) val = tl.arange(START, END)
tl.store(z + off, val) tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK) _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device) z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref) triton.testing.assert_almost_equal(z_tri, z_ref)
# --------------- # ---------------
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
# --------------- # ---------------
# 'bfloat16': torch.bfloat16, # 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run. # Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'): def test_masked_load_shared_memory(dtype, device='cuda'):
M = 32 M = 32
@@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
N_offsets = tl.arange(0, N) N_offsets = tl.arange(0, N)
K_offsets = tl.arange(0, K) K_offsets = tl.arange(0, K)
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs. # Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel) x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
o = tl.dot(x, w) o = tl.dot(x, w)
# Store output # Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel) tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
pgm = _kernel[(1,)](in1, in2, out, pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0], in1.stride()[0],
in2.stride()[0], in2.stride()[0],
out.stride()[0], out.stride()[0],
in1.numel(), in1.numel(),
in2.numel(), in2.numel(),
out.numel(), out.numel(),
M=M, N=N, K=K) M=M, N=N, K=K)
reference_out =torch.matmul(in1, in2) reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out) triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache): def test_load_cache_modifier(cache):
src = torch.empty(128, device='cuda') src = torch.empty(128, device='cuda')
@@ -796,8 +815,8 @@ def test_load_cache_modifier(cache):
@triton.jit @triton.jit
def _kernel(dst, src, CACHE: tl.constexpr): def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128) offsets = tl.arange(0, 128)
x = tl.load(src+offsets, cache_modifier=CACHE) x = tl.load(src + offsets, cache_modifier=CACHE)
tl.store(dst+offsets, x) tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache) pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx'] ptx = pgm.asm['ptx']
@@ -830,11 +849,14 @@ def test_load_cache_modifier(cache):
# --------------- # ---------------
# test default # test default
# --------------- # ---------------
#TODO: can't be local to test_default # TODO: can't be local to test_default
@triton.jit @triton.jit
def _impl(value = 10): def _impl(value=10):
return value return value
def test_default(): def test_default():
value = 5 value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -851,7 +873,9 @@ def test_default():
# --------------- # ---------------
# test noop # test noop
#---------------- # ----------------
def test_noop(device='cuda'): def test_noop(device='cuda'):
@triton.jit @triton.jit
def kernel(x): def kernel(x):
@@ -861,9 +885,9 @@ def test_noop(device='cuda'):
@pytest.mark.parametrize("value, value_type", [ @pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'), (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'), (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64') (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
]) ])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None: def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
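Most of the churn in the test file above is autopep8 adding blank lines: one before the `@triton.jit` kernels defined inside test helpers (E306, expected 1 blank line before a nested definition) and two around top-level `@pytest.mark.parametrize` blocks (E302/E305). A plain-Python stand-in for the nested case:

```python
def _test_unary_sketch(size=128):
    # setup code for the test...
    data = list(range(size))

    # E306: one blank line is required before this nested definition
    # (a plain function here, standing in for an @triton.jit kernel).
    def kernel(values):
        return [v + 1 for v in values]

    return kernel(data)[:3]

print(_test_unary_sketch())   # -> [1, 2, 3]
```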


@@ -1,16 +1,17 @@
import torch import numpy as np
import triton
import triton.language as tl
import pytest import pytest
import scipy.stats import scipy.stats
import numpy as np import torch
from numpy.random import Philox from numpy.random import Philox
import triton
import triton.language as tl
##################################### #####################################
## Reference Philox Implementation # Reference Philox Implementation
##################################### #####################################
class PhiloxConfig: class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
##################################### #####################################
## Unit Tests # Unit Tests
##################################### #####################################
BLOCK = 1024 BLOCK = 1024
# test generation of random uint32 # test generation of random uint32
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\ [(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
) )
def test_randint(size, seed, device='cuda'): def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(','))) size = list(map(int, size.split(',')))
@triton.jit @triton.jit
def kernel(X, N, seed): def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'):
assert out_tri == out_ref assert out_tri == out_ref
# test uniform PRNG # test uniform PRNG
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\ [(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]] for seed in [0, 42, 124, 54]]
) )
def test_rand(size, seed, device='cuda'): def test_rand(size, seed, device='cuda'):
@triton.jit @triton.jit
def kernel(X, N, seed): def kernel(X, N, seed):
@@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG # test normal PRNG
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\ [(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]] for seed in [0, 42, 124, 54]]
) )
def test_randn(size, seed, device='cuda'): def test_randn(size, seed, device='cuda'):
@triton.jit @triton.jit
def kernel(X, N, seed): def kernel(X, N, seed):
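In the random-number tests above, the `## ... ##` banner comments lose their doubled hashes: E266 flags block comments starting with `##`, and `# ` satisfies E265. A tiny sketch (the constant is the Philox key quoted elsewhere in this diff):

```python
#####################################
# Reference Philox implementation
#####################################
# E266: block comments should not start with '##'; E265: they should start
# with '# ' -- hence the banners above lose their doubled hashes.
PHILOX_KEY_A = 0x9E3779B9
print(hex(PHILOX_KEY_A))
```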


@@ -1,6 +1,7 @@
import torch
import triton
import pytest import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
# torch result # torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx # broadcast at_mask to the same shape as rx
if is_causal: at_mask = torch.tril(at_mask) if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx) M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf") rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :] # rx += kp_mask[:, None, None, :]


@@ -1,14 +1,16 @@
import torch
import triton
import pytest import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode", @pytest.mark.parametrize("M, N, dtype, mode",
[ [
(M, N, dtype, mode) for M in [1024, 821] (M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000] for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']\ for dtype in ['float16', 'float32']
for mode in ['forward', 'backward'] for mode in ['forward', 'backward']
] ]
) )
def test_op(M, N, dtype, mode): def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype] dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]


@@ -1,8 +1,10 @@
import pytest
import itertools import itertools
import triton
import pytest
import torch import torch
import triton
@pytest.mark.parametrize( @pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE", "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
@@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
K = BLOCK_K * SPLIT_K if K is None else K K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs # allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a a = a.t() if AT else a
b = b.t() if BT else b b = b.t() if BT else b
# run test # run test
th_c = torch.matmul(a, b) th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest) tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c) triton.testing.assert_almost_equal(th_c, tt_c)
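Two micro-fixes show up in the matmul test above: `.1*torch.randn(...)` gains spaces around `*` (E226), and `lambda : ...` loses the space before the colon, which pycodestyle reports as whitespace before `:` (E203). Illustrated with plain floats:

```python
a_scale = .1 * 4.0               # was: .1*4.0 (E226, missing whitespace around '*')
run = lambda: a_scale + 1.0      # was written "lambda : ..." with a stray space before ':'
print(run())
```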


@@ -1,13 +1,16 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os import os
import shutil import shutil
import pytest import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp" tmpdir = ".tmp"
@triton.jit @triton.jit
def function_1(i): def function_1(i):
i = i + 1 i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1 i = i + 1
return i return i
@triton.jit @triton.jit
def kernel(X, i, BLOCK: tl.constexpr): def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1 i = i + 1
i = function_1(i) i = function_1(i)
tl.store(X, i) tl.store(X, i)
@triton.jit(do_not_specialize=["i"]) @triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr): def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1 i = i + 1
i = function_1(i) i = function_1(i)
tl.store(X, i) tl.store(X, i)
def apply_src_change(target, old, new): def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash') delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash') delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old) target.src = target.src.replace(new, old)
return ret return ret
def test_nochange(): def test_nochange():
baseline = kernel.cache_key baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1') updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated assert baseline == updated
def test_toplevel_change(): def test_toplevel_change():
baseline = kernel.cache_key baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2') updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated assert baseline != updated
def test_nested1_change(): def test_nested1_change():
baseline = kernel.cache_key baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2') updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated assert baseline != updated
def reset_tmp_dir(): def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir): if os.path.exists(tmpdir):
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
def test_reuse(): def test_reuse():
counter = 0 counter = 0
def inc_counter(key, binary, repr): def inc_counter(key, binary, repr):
nonlocal counter nonlocal counter
counter += 1 counter += 1
@@ -78,6 +90,7 @@ def test_reuse():
@pytest.mark.parametrize('mode', ['enable', 'disable']) @pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode): def test_specialize(mode):
counter = 0 counter = 0
def inc_counter(key, binary, repr): def inc_counter(key, binary, repr):
nonlocal counter nonlocal counter
counter += 1 counter += 1


@@ -1,9 +1,11 @@
import torch
import triton
import pytest
import subprocess import subprocess
import triton.language as tl
import numpy as np import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix(): def get_p2p_matrix():


@@ -1,26 +1,26 @@
import ast import ast
import builtins import builtins
import dbm
import functools import functools
import inspect
import struct
import sys
import textwrap
import hashlib import hashlib
import inspect
import os import os
import pickle import pickle
import struct
import subprocess import subprocess
import os import sys
import tempfile
import textwrap
import time
import warnings import warnings
from .tools.disasm import extract from typing import Dict, Optional
import torch import torch
from filelock import FileLock
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from filelock import FileLock from .tools.disasm import extract
import dbm
import tempfile
from typing import Optional, Dict
import time
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
@@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor):
arg_names, kwarg_names = self.visit(node.args) arg_names, kwarg_names = self.visit(node.args)
# initialize defaults # initialize defaults
for i, default_value in enumerate(node.args.defaults): for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i-1] arg_node = node.args.args[-i - 1]
annotation = arg_node.annotation annotation = arg_node.annotation
name = arg_node.arg name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store()) st_target = ast.Name(id=name, ctx=ast.Store())
@@ -135,7 +135,6 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(fn.args[idx]) arg_values.append(fn.args[idx])
idx += 1 idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value) self.set_value(arg_name, arg_value)
if inline: if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
# default: call visit_Assign # default: call visit_Assign
return self.visit_Assign(node) return self.visit_Assign(node)
def visit_Assign(self, node): def visit_Assign(self, node):
_names = [] _names = []
for target in node.targets: for target in node.targets:
@@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
if else_bb: if else_bb:
self.builder.set_insert_block(else_bb) self.builder.set_insert_block(else_bb)
is_terminator = self.visit_compound_statement(node.orelse) is_terminator = self.visit_compound_statement(node.orelse)
#TODO: last statement is a terminator? # TODO: last statement is a terminator?
if not is_terminator: if not is_terminator:
self.builder.br(endif_bb) self.builder.br(endif_bb)
self.module.seal_block(endif_bb) self.module.seal_block(endif_bb)
@@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ build_cond = lambda: triton.language.where(self.visit(pos_step_node),
self.visit(pos_cond_node),\ self.visit(pos_cond_node),
self.visit(neg_cond_node),\ self.visit(neg_cond_node),
_builder=self.builder) _builder=self.builder)
#cond_node = neg_cond_node #cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation # code generation
@@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction): if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws) return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core: sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws) return fn(*args, _builder=self.builder, **kws)
return fn(*args, **kws) return fn(*args, **kws)
@@ -632,10 +630,14 @@ class Kernel:
@staticmethod @staticmethod
def pow2_divisor(N): def pow2_divisor(N):
if N % 16 == 0: return 16 if N % 16 == 0:
if N % 8 == 0: return 8 return 16
if N % 4 == 0: return 4 if N % 8 == 0:
if N % 2 == 0: return 2 return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1 return 1
def __init__(self, fn): def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes # attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize} if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation # transforms ints whose value is one into constants for just-in-time compilation
@@ -768,9 +770,8 @@ class Launcher:
return self.kernel(*wargs, **kwargs, grid=self.grid) return self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner: class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None): def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
''' '''
:param prune_configs_by: a dict of functions that are used to prune configs, fields: :param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time 'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -788,6 +789,7 @@ class Autotuner:
self.hook = lambda args: 0 self.hook = lambda args: 0
if reset_to_zero is not None: if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero] self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args): def _hook(args):
for i in self.reset_idx: for i in self.reset_idx:
args[i].zero_() args[i].zero_()
@@ -814,6 +816,7 @@ class Autotuner:
) )
# augment meta-parameters with tunable ones # augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs) current = dict(meta, **config.kwargs)
def kernel_call(): def kernel_call():
if config.pre_hook: if config.pre_hook:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
@@ -836,9 +839,9 @@ class Autotuner:
top_k = int(len(self.configs) * top_k) top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k: if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k] pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time() bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \ timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs} for config in pruned_configs}
bench_end = time.time() bench_end = time.time()
self.bench_time = bench_end - bench_start self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
ptxas_version = '' ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
#########################3 # 3
class DependenciesFinder(ast.NodeVisitor): class DependenciesFinder(ast.NodeVisitor):
@@ -917,11 +920,11 @@ class DependenciesFinder(ast.NodeVisitor):
self.ret = (self.ret + func.hash).encode("utf-8") self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest() self.ret = hashlib.md5(self.ret).hexdigest()
class JITFunction: class JITFunction:
cache_hook = None cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None): def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function # information of wrapped function
self.fn = fn self.fn = fn
@@ -946,7 +949,6 @@ class JITFunction:
# forward docs # forward docs
self.__doc__ = fn.__doc__ self.__doc__ = fn.__doc__
@property @property
@functools.lru_cache() @functools.lru_cache()
def cache_key(self): def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args. function are args.
""" """
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs self.kwargs = kwargs
self.num_warps = num_warps self.num_warps = num_warps
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
def cdiv(x, y): def cdiv(x, y):
return (x + y - 1) // y return (x + y - 1) // y
def next_power_of_2(n): def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n""" """Return the smallest power of 2 greater than or equal to n"""
n -= 1 n -= 1
@@ -1163,10 +1167,11 @@ def next_power_of_2(n):
###### ######
class TensorWrapper: class TensorWrapper:
def __init__(self, base, dtype): def __init__(self, base, dtype):
self.dtype = dtype self.dtype = dtype
self.base = base self.base = base
self.is_cuda = base.is_cuda self.is_cuda = base.is_cuda
self.device = base.device self.device = base.device
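One of the few multi-line rewrites above is `Kernel.pow2_divisor`, whose `if N % 16 == 0: return 16` one-liners are split onto separate lines (one statement per line, the E701 rule, presumably as part of the hand cleanup since E701 is otherwise in the `--ignore` list). A stand-alone, plain-Python copy of the reformatted helper:

```python
def pow2_divisor(n):
    # Largest power of two (up to 16) that divides n; one statement per line
    # instead of the old "if n % 16 == 0: return 16" one-liners.
    if n % 16 == 0:
        return 16
    if n % 8 == 0:
        return 8
    if n % 4 == 0:
        return 4
    if n % 2 == 0:
        return 2
    return 1

print([pow2_divisor(n) for n in (48, 20, 6, 7)])   # -> [16, 4, 2, 1]
```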


@@ -1,8 +1,8 @@
import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
from functools import wraps from functools import wraps
import triton
from triton._C.libtriton.triton import frontend, ir
# convert block/dtype to ir values # convert block/dtype to ir values
def _to_ir(x, builder): def _to_ir(x, builder):
@@ -65,7 +65,7 @@ def builtin(fn):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \ if '_builder' not in kwargs or \
kwargs['_builder'] is None: kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs) return fn(*args, **kwargs)
return wrapper return wrapper
@@ -111,6 +111,7 @@ class pointer_dtype:
def __str__(self): def __str__(self):
return f'pointer<{self.element_ty}>' return f'pointer<{self.element_ty}>'
# scalar types # scalar types
int1 = dtype(ir.type.get_int1) int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8) int8 = dtype(ir.type.get_int8)
@@ -489,6 +490,7 @@ def broadcast_to(input, shape, _builder=None):
""" """
return frontend.broadcast_to(input, shape, _builder) return frontend.broadcast_to(input, shape, _builder)
@builtin @builtin
def cat(input, other, _builder=None): def cat(input, other, _builder=None):
""" """
@@ -603,6 +605,7 @@ def _add_atomic_docstr(name):
return _decorator return _decorator
@builtin @builtin
@_add_atomic_docstr("compare-and-swap") @_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None): def atomic_cas(pointer, cmp, val, _builder=None):
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
def atomic_xchg(pointer, val, mask=None, _builder=None): def atomic_xchg(pointer, val, mask=None, _builder=None):
return frontend.atomic_xchg(pointer, val, mask, _builder) return frontend.atomic_xchg(pointer, val, mask, _builder)
@builtin @builtin
@_add_atomic_docstr("add") @_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None): def atomic_add(pointer, val, mask=None, _builder=None):
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
def umulhi(x, y, _builder=None): def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder) return frontend.umulhi(x, y, _builder)
def _add_math_1arg_docstr(name): def _add_math_1arg_docstr(name):
def _decorator(func): def _decorator(func):
@@ -697,21 +702,25 @@ def _add_math_1arg_docstr(name):
return _decorator return _decorator
@builtin @builtin
@_add_math_1arg_docstr("exponential") @_add_math_1arg_docstr("exponential")
def exp(x, _builder=None): def exp(x, _builder=None):
return frontend.exp(x, _builder) return frontend.exp(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("natural logarithm") @_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None): def log(x, _builder=None):
return frontend.log(x, _builder) return frontend.log(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("cosine") @_add_math_1arg_docstr("cosine")
def cos(x, _builder=None): def cos(x, _builder=None):
return frontend.cos(x, _builder) return frontend.cos(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("sine") @_add_math_1arg_docstr("sine")
def sin(x, _builder=None): def sin(x, _builder=None):
@@ -742,6 +751,7 @@ def _add_reduction_docstr(name):
return _decorator return _decorator
@builtin @builtin
@_add_reduction_docstr("maximum") @_add_reduction_docstr("maximum")
def max(input, axis, _builder=None): def max(input, axis, _builder=None):
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
def sum(input, axis, _builder=None): def sum(input, axis, _builder=None):
return frontend.sum(input, axis, _builder) return frontend.sum(input, axis, _builder)
@builtin @builtin
@_add_reduction_docstr("xor sum") @_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None): def xor_sum(input, axis, _builder=None):
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
def abs(x): def abs(x):
return where(x >= 0, x, -x) return where(x >= 0, x, -x)
@triton.jit @triton.jit
def cdiv(x, div): def cdiv(x, div):
""" """
@@ -871,6 +883,7 @@ def ravel(x):
""" """
return triton.language.reshape(x, [x.type.numel]) return triton.language.reshape(x, [x.type.numel])
@triton.jit @triton.jit
def swizzle2d(i, j, size_i, size_j, size_g): def swizzle2d(i, j, size_i, size_j, size_g):
""" """
@@ -888,7 +901,7 @@ def swizzle2d(i, j, size_i, size_j, size_g):
[9, 11, 13, 15]] [9, 11, 13, 15]]
""" """
# "unrolled index in array" # "unrolled index in array"
ij = i*size_j + j ij = i * size_j + j
# number of elements in `size_g` groups # number of elements in `size_g` groups
# of `size_j` columns # of `size_j` columns
size_gj = size_g * size_j size_gj = size_g * size_j


@@ -1,7 +1,6 @@
import triton import triton
from . import core as tl from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53

View File

@@ -1,4 +1,4 @@
#from .conv import _conv, conv #from .conv import _conv, conv
from .matmul import _matmul, matmul
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .matmul import _matmul, matmul

View File

@@ -1,6 +1,7 @@
import torch
import triton import triton
import triton.language as tl import triton.language as tl
import torch
# ******************************************************** # ********************************************************
# -------------------------------------------------------- # --------------------------------------------------------
@@ -11,6 +12,7 @@ import torch
# -------------------------------------------------------- # --------------------------------------------------------
# ******************************************************** # ********************************************************
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
}) })
@@ -30,25 +32,25 @@ def _sdd_kernel(
block_id = tl.program_id(1) + grid_offset block_id = tl.program_id(1) + grid_offset
lut += block_id * 3 lut += block_id * 3
# offsets # offsets
off_z = tl.program_id(2) # batch off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head off_h = tl.load(lut + 0) # head
# initialize pointers to A # initialize pointers to A
start_am = tl.load(lut + 1) start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K) offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + (off_z * stride_za \ a_ptrs = A + (off_z * stride_za
+ off_h * stride_ha \ + off_h * stride_ha
+ offs_am[:, None] * stride_ma \ + offs_am[:, None] * stride_ma
+ offs_ak[None, :] * stride_ak) + offs_ak[None, :] * stride_ak)
# initialize pointers to B # initialize pointers to B
start_bn = tl.load(lut + 2) start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K) offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + (off_z * stride_zb \ b_ptrs = B + (off_z * stride_zb
+ off_h * stride_hb \ + off_h * stride_hb
+ offs_bn[None, :] * stride_nb \ + offs_bn[None, :] * stride_nb
+ offs_bk[:, None] * stride_bk) + offs_bk[:, None] * stride_bk)
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
@@ -69,13 +71,14 @@ def _sdd_kernel(
## ---------------- ## ## ---------------- ##
offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + (off_z * stride_zc \ pc = C + (off_z * stride_zc
+ block_id * stride_hc \ + block_id * stride_hc
+ offs_cm[:, None] * stride_mc \ + offs_cm[:, None] * stride_mc
+ offs_cn[None, :] * stride_nc) + offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True) tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1: if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous() a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1: if b.stride(2) != 1 and b.stride(3) != 1:
@@ -103,7 +106,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3), c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut, Ka, 0, lut,
TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4, TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4, num_warps=4,
) )
return c return c
@@ -119,6 +122,8 @@ def sdd_lut(layout, block, device):
# This operation uses a look-up table that contains pre-computed pointer increments # This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel. # in order to minimize computations in the inner loop of the matmul kernel.
# ----------------------------- # -----------------------------
@triton.jit @triton.jit
def _dsd_kernel( def _dsd_kernel(
A, B, C, A, B, C,
@@ -132,37 +137,37 @@ def _dsd_kernel(
#------------# #------------#
#- Prologue -# #- Prologue -#
#------------# #------------#
pid_m = tl.program_id(0) pid_m = tl.program_id(0)
pid_n = tl.program_id(1) pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0) num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1) num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2) pidz = tl.program_id(2)
header = lut + pid_n * 4 header = lut + pid_n * 4
offset = tl.load(header + 0) offset = tl.load(header + 0)
K = tl.load(header + 1) K = tl.load(header + 1)
column = tl.load(header + 2) column = tl.load(header + 2)
off_h = tl.load(header + 3) off_h = tl.load(header + 3)
pinc = lut + offset pinc = lut + offset
# initialize pointers to A (sparse) # initialize pointers to A (sparse)
block_id = tl.load(pinc + 1) block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M) offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K) offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \ pa = A + pidz * stride_az \
+ block_id * stride_ha \ + block_id * stride_ha \
+ offs_am[:, None] * stride_am \ + offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak + offs_ak[None, :] * stride_ak
# initialize pointers to B (dense) # initialize pointers to B (dense)
offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N) offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc) start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K) offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \ pb = B + pidz * stride_zb \
+ off_h * stride_hb \ + off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \ + offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk + offs_bk[:, None] * stride_bk
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
@@ -177,7 +182,7 @@ def _dsd_kernel(
b = tl.load(pb, mask=offs_bn[None, :] < DS0) b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b) acc += tl.dot(a, b)
pa += inc_a pa += inc_a
pb += inc_b*stride_bk pb += inc_b * stride_bk
pinc += 2 pinc += 2
inc_a = tl.load(pinc + 1) inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8) inc_a = tl.multiple_of(inc_a, 8)
@@ -185,15 +190,16 @@ def _dsd_kernel(
inc_b = tl.multiple_of(inc_b, 8) inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty) c = acc.to(C.dtype.element_ty)
# initialize pointers to C # initialize pointers to C
offs_cm = column*TILE_M + tl.arange(0, TILE_M) offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N) offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \ pc = C + off_h * stride_hc \
+ pidz * stride_zc \ + pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \ + offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn + offs_cn[None, :] * stride_cn
tl.store(pc, c, mask = offs_cn[None, :] < DS0) tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1: if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous() a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1: if b.stride(2) != 1 and b.stride(3) != 1:

View File

@@ -1,7 +1,8 @@
import triton.language as tl
import triton
import torch import torch
import triton
import triton.language as tl
def num_warps(n): def num_warps(n):
if n < 512: if n < 512:
@@ -33,10 +34,10 @@ def _forward(
check = rbn < size check = rbn < size
rbmn = tl.where(check, rbn, size - 1) rbmn = tl.where(check, rbn, size - 1)
# block id and column id # block id and column id
blockid = tl.load(LUT + offset + rbmn * 4 + 0) blockid = tl.load(LUT + offset + rbmn * 4 + 0)
columnid = tl.load(LUT + offset + rbmn * 4 + 1) columnid = tl.load(LUT + offset + rbmn * 4 + 1)
rowid = tl.load(LUT + offset + rbmn * 4 + 2) rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = tl.load(LUT + offset + rbmn * 4 + 3) headid = tl.load(LUT + offset + rbmn * 4 + 3)
# pointers to X # pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = tl.load(px, mask=check, other=-float('inf')) x = tl.load(px, mask=check, other=-float('inf'))
@@ -64,7 +65,7 @@ def _forward(
attn_m = tl.where(attn_m == 0, -float('inf'), 0.) attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m x = x + attn_m
# apply causal mask # apply causal mask
is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
# computation # computation
x = tl.softmax(x) x = tl.softmax(x)
@@ -161,15 +162,15 @@ class _softmax(torch.autograd.Function):
# run kernel # run kernel
M = x.shape[0] M = x.shape[0]
grid = [spdims[0] * spdims[1] * block, M] grid = [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\ _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK = block, BLOCK=block,
APPLY_SCALE = apply_scale, APPLY_SCALE=apply_scale,
APPLY_RPE = apply_rpe, APPLY_RPE=apply_rpe,
APPLY_KP_MASK = apply_kp_mask, APPLY_KP_MASK=apply_kp_mask,
APPLY_ATTN_MASK = apply_attn_mask, APPLY_ATTN_MASK=apply_attn_mask,
KP_MASK_MUL = (kp_mask_mode == 'mul'), KP_MASK_MUL=(kp_mask_mode == 'mul'),
ATTN_MASK_MUL = (attn_mask_mode == 'mul')) ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
# save to context # save to context
ctx.mark_dirty(x) ctx.mark_dirty(x)
ctx.save_for_backward(x, lut) ctx.save_for_backward(x, lut)
@@ -214,7 +215,7 @@ class softmax:
self, x, scale=1., rpe=None, self, x, scale=1., rpe=None,
key_padding_mask=None, attn_mask=None, key_padding_mask=None, attn_mask=None,
key_padding_mask_mode='add', attn_mask_mode='add', key_padding_mask_mode='add', attn_mask_mode='add',
is_causal = False is_causal=False
): ):
if rpe is not None and rpe.dtype != x.dtype: if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype) raise ValueError('relative position embedding must be %s' % x.dtype)

View File

@@ -1,7 +1,9 @@
import os import os
import torch
import triton import triton
import triton.language as tl import triton.language as tl
import torch
def next_power_of_2(n): def next_power_of_2(n):

View File

@@ -1,11 +1,14 @@
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
from .matmul_perf_model import * from .matmul_perf_model import *
def init_to_zero(name): def init_to_zero(name):
return lambda nargs: nargs[name].zero_() return lambda nargs: nargs[name].zero_()
def get_configs_io_bound(): def get_configs_io_bound():
configs = [] configs = []
for num_stages in [2, 3, 4, 5, 6]: for num_stages in [2, 3, 4, 5, 6]:
@@ -15,13 +18,14 @@ def get_configs_io_bound():
num_warps = 2 if block_n <= 64 else 4 num_warps = 2 if block_n <= 64 else 4
configs.append( configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps)) num_stages=num_stages, num_warps=num_warps))
# split_k # split_k
for split_k in [2, 4, 8, 16]: for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs return configs
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
}) })
@@ -30,19 +34,19 @@ def get_configs_io_bound():
# basic configs for compute-bound matmuls # basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(), ] + get_configs_io_bound(),
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
prune_configs_by={ prune_configs_by={
'prune_num_stages_by' : prune_num_stages, 'prune_num_stages_by': prune_num_stages,
'perf_model': estimate_matmul_time, 'perf_model': estimate_matmul_time,
'top_k': 10 'top_k': 10
}, },
) )
@triton.jit @triton.jit
@@ -68,12 +72,12 @@ def _kernel(A, B, C, M, N, K,
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K) rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers # pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K*SPLIT_K): for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K: if EVEN_K:
a = tl.load(A) a = tl.load(A)
b = tl.load(B) b = tl.load(B)
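
A note on the SPLIT_K configurations above: when SPLIT_K > 1, the K reduction is divided across program instances along the third grid axis and each instance accumulates its partial product into C, which is why those configs carry pre_hook=init_to_zero('C'). A host-side illustration of the idea (the kernel itself walks interleaved BLOCK_K-sized chunks rather than contiguous slices):

```python
import torch

# Illustrative split-K reduction: each "slice" computes a partial product
# over K // SPLIT_K elements and accumulates into C, so C must start at zero
# (the job of the init_to_zero('C') pre_hook in the autotuner configs).
M, N, K, SPLIT_K = 64, 64, 128, 4
a, b = torch.randn(M, K), torch.randn(K, N)
c = torch.zeros(M, N)
for s in range(SPLIT_K):
    ks = slice(s * K // SPLIT_K, (s + 1) * K // SPLIT_K)
    c += a[:, ks] @ b[ks, :]   # one program slice's contribution
print(torch.allclose(c, a @ b, atol=1e-4))  # True
```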

View File

@@ -1,116 +1,121 @@
import heapq
import torch import torch
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps): def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS ''' ''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4) total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device) tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops return tflops
def estimate_matmul_time( def estimate_matmul_time(
# backend, device, # backend, device,
num_warps, num_stages, num_warps, num_stages,
M, N, K, M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs debug=False, **kwargs
): ):
''' return estimated running time in ms ''' return estimated running time in ms
= max(compute, loading) + store ''' = max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
num_cta_m = triton.cdiv(M, BLOCK_M) num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N) num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size # If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N) M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute # time to compute
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
compute_ms = total_ops / tput compute_ms = total_ops / tput
# time to load data # time to load data
num_sm = _triton.runtime.num_sm(backend, device) num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas/num_sm) active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5% active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache # assume 80% of (following) loads are in L2 cache
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2) load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2)
load_a_l2 = M*K*2*0.8*(num_cta_n-1) load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1)) load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N*K*2*0.8*(num_cta_m-1) load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
# total # total
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024) total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
# loading time in ms # loading time in ms
load_ms = total_dram/dram_bw + total_l2/l2_bw load_ms = total_dram / dram_bw + total_l2 / l2_bw
# estimate storing time # estimate storing time
store_bw = dram_bw * 0.6 # :o store_bw = dram_bw * 0.6 # :o
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1: if SPLIT_K == 1:
store_ms = store_c_dram /store_bw store_ms = store_c_dram / store_bw
else: else:
reduce_bw = store_bw reduce_bw = store_bw
store_ms = store_c_dram/reduce_bw store_ms = store_c_dram / reduce_bw
# c.zero_() # c.zero_()
zero_ms = M*N*2/(1024*1024)/store_bw zero_ms = M * N * 2 / (1024 * 1024) / store_bw
store_ms += zero_ms store_ms += zero_ms
total_time_ms = max(compute_ms, load_ms) + store_ms total_time_ms = max(compute_ms, load_ms) + store_ms
if debug: if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, ' f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Activate CTAs: {active_cta_ratio*100}%') f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms return total_time_ms
def prune_num_stages(configs): def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device) cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {} configs_map = {}
for config in configs: for config in configs:
kw = config.kwargs kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map: if key in configs_map:
configs_map[key].append((config, num_stages)) configs_map[key].append((config, num_stages))
else: else:
configs_map[key] = [(config, num_stages)] configs_map[key] = [(config, num_stages)]
pruned_configs = [] pruned_configs = []
for k, v in configs_map.items(): for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80: if cc >= 80:
# compute cycles (only works for ampere GPUs) # compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16) mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
mma_cycles = mmas/min(4, num_warps) * 8 mma_cycles = mmas / min(4, num_warps) * 8
ldgsts_latency = 300 # Does this matter? ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency/mma_cycles optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages # nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \ nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest: for n in nearest:
pruned_configs.append(n[0]) pruned_configs.append(n[0])
else: # Volta & Turing only supports num_stages <= 2 else: # Volta & Turing only supports num_stages <= 2
random_config = v[0][0] random_config = v[0][0]
random_config.num_stages = 2 random_config.num_stages = 2
pruned_configs.append(random_config) pruned_configs.append(random_config)
return pruned_configs return pruned_configs
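
For readers new to this file: estimate_matmul_time is a roofline-style model, total time ≈ max(compute, load) + store, with loads split between DRAM and an assumed 80%-effective L2. A stripped-down sketch with illustrative peak numbers (312 TFLOPS and 1555 GB/s, roughly an A100; the real code queries these from the runtime) shows the shape of the calculation:

```python
def roofline_ms(flops, bytes_loaded, bytes_stored,
                peak_tflops=312.0, dram_gbps=1555.0):
    # compute-bound term
    compute_ms = flops / (peak_tflops * 1e12) * 1e3
    # memory-bound terms (no L2 or occupancy modelling here,
    # unlike estimate_matmul_time above)
    load_ms = bytes_loaded / (dram_gbps * 1e9) * 1e3
    store_ms = bytes_stored / (dram_gbps * 1e9) * 1e3
    return max(compute_ms, load_ms) + store_ms

# fp16 GEMM with M = N = K = 4096, each operand read once, C written once
M = N = K = 4096
print(roofline_ms(2 * M * N * K, (M * K + K * N) * 2, M * N * 2))  # ~0.46 ms
```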

View File

@@ -1,10 +1,11 @@
import torch
import os import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
import subprocess import subprocess
import sys import sys
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
try: try:
import triton._C.libtriton.cutlass as _cutlass import triton._C.libtriton.cutlass as _cutlass
@@ -13,6 +14,7 @@ except ImportError:
_cutlass = None _cutlass = None
has_cutlass = False has_cutlass = False
def catch_oor(kernel, pytest_handle=None): def catch_oor(kernel, pytest_handle=None):
try: try:
res = kernel() res = kernel()
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device) c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
# run function # run function
dtype = str(a.dtype).split('.')[-1] dtype = str(a.dtype).split('.')[-1]
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, Ka,\ M, N, Ka,
a.stride(0), a.stride(1),\ a.stride(0), a.stride(1),
b.stride(0), b.stride(1),\ b.stride(0), b.stride(1),
c.stride(0), c.stride(1),\ c.stride(0), c.stride(1),
dtype, dtype, dtype, dtype, dtype, dtype,
a.device.index, torch.cuda.current_stream(a.device).cuda_stream) a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret return ret
def assert_almost_equal(x, y, decimal=2, err_msg=''): def assert_almost_equal(x, y, decimal=2, err_msg=''):
import numpy.testing as npt import numpy.testing as npt
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
@@ -93,6 +96,7 @@ def nvsmi(attrs):
ret = [int(x) for x in ret] ret = [int(x) for x in ret]
return ret return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False): def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
""" """
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -122,13 +126,13 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
torch.cuda.synchronize() torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5 estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat # compute number of warmup and repeat
n_warmup = max(1, int(warmup/estimate_ms)) n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep/estimate_ms)) n_repeat = max(1, int(rep / estimate_ms))
# We maintain a buffer of 256 MB that we clear # We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2 # before each kernel call to make sure that the L2
# doesn't contain any input data before the run # doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up # Warm-up
for _ in range(n_warmup): for _ in range(n_warmup):
@@ -161,6 +165,7 @@ class Benchmark:
""" """
This class is used by the :code:`perf_report` function to generate line plots with a concise API. This class is used by the :code:`perf_report` function to generate line plots with a concise API.
""" """
def __init__( def __init__(
self, self,
x_names, x_names,
@@ -224,9 +229,10 @@ class Mark:
self.benchmarks = benchmarks self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data): def _run(self, bench, save_path, show_plots, print_data):
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import os
y_mean = bench.line_names y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names] y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names] y_max = [f'{x}-max' for x in bench.line_names]
@@ -259,7 +265,7 @@ class Mark:
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names) xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
ax.set_xlabel(xlabel) ax.set_xlabel(xlabel)
ax.set_ylabel(bench.ylabel) ax.set_ylabel(bench.ylabel)
#ax.set_title(bench.plot_name) # ax.set_title(bench.plot_name)
ax.set_xscale("log" if bench.x_log else "linear") ax.set_xscale("log" if bench.x_log else "linear")
ax.set_yscale("log" if bench.y_log else "linear") ax.set_yscale("log" if bench.y_log else "linear")
if show_plots: if show_plots:
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks) wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper return wrapper
def get_dram_gbps(backend=None, device=None): def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s ''' ''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA # assert backend == CUDA
@@ -306,17 +313,18 @@ def get_dram_gbps(backend=None, device=None):
device = torch.cuda.current_device() device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device) bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps return bw_gbps
def get_max_tensorcore_tflops(backend, device): def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
# assume fp32 += fp16*fp16 # assume fp32 += fp16*fp16
cc = _triton.runtime.cc(backend, device) cc = _triton.runtime.cc(backend, device)
if cc < 80: if cc < 80:
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else: else:
ops_per_sub_core = 512 ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
return tflops return tflops
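
The helpers in this file compose the same way the benchmarks elsewhere in this diff use them: with the default percentiles, do_bench returns median, low and high runtimes in milliseconds, and get_dram_gbps gives a peak to compare against. A small usage sketch (the tensor shape and elementwise op are arbitrary):

```python
import torch
import triton

x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
# median / low / high runtimes in ms for a simple elementwise op
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x * 2)
bytes_moved = 2 * x.numel() * x.element_size()   # one read + one write
achieved = bytes_moved / (ms * 1e-3) * 1e-9      # GB/s
print(f'{achieved:.0f} GB/s achieved, {triton.testing.get_dram_gbps()} GB/s peak')
```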

View File

@@ -21,8 +21,8 @@
# SOFTWARE. # SOFTWARE.
import argparse import argparse
import subprocess
import re import re
import subprocess
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')

View File

@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel # Compute Kernel
# -------------------------- # --------------------------
from triton.language.core import constexpr
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl

View File

@@ -16,6 +16,8 @@ You will learn about:
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation: # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import triton.language as tl
import triton
import torch import torch
@@ -59,9 +61,6 @@ def naive_softmax(x):
# power-of-two number of elements, so we need to internally "pad" each row and guard the # power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes: # memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit @triton.jit
def softmax_kernel( def softmax_kernel(
@@ -136,7 +135,7 @@ y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1) y_torch = torch.softmax(x, axis=1)
print(torch.allclose(y_triton, y_torch)) print(torch.allclose(y_triton, y_torch))
#%% # %%
# As expected, the results are identical. # As expected, the results are identical.
# %% # %%

View File

@@ -141,6 +141,7 @@ You will specifically learn about:
# #
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -152,18 +153,19 @@ import triton.language as tl
# - An autotuning *key* whose change in values will trigger evaluation of all the # - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs # provided configs
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
], ],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
) )
@@ -185,7 +187,7 @@ def matmul_kernel(
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr, ACTIVATION: tl.constexpr,
): ):
"""Kernel for computing the matmul C = A x B. """Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N) A has shape (M, K), B has shape (K, N) and C has shape (M, N)
""" """
@@ -213,8 +215,8 @@ def matmul_kernel(
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Iterate to compute a block of the C matrix # Iterate to compute a block of the C matrix

View File

@@ -30,16 +30,18 @@ whose state is generally composed of a bit mask tensor of the same shape as the
import tabulate import tabulate
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@triton.jit @triton.jit
def _dropout( def _dropout(
x_ptr, # pointer to the input x_ptr, # pointer to the input
x_keep_ptr, # pointer to a mask of 0s and 1s x_keep_ptr, # pointer to a mask of 0s and 1s
output_ptr, # pointer to the output output_ptr, # pointer to the output
n_elements, # number of elements in the `x` tensor n_elements, # number of elements in the `x` tensor
p, # probability that an element of `x` is changed to zero p, # probability that an element of `x` is changed to zero
**meta, **meta,
): ):
BLOCK_SIZE = meta['BLOCK_SIZE'] BLOCK_SIZE = meta['BLOCK_SIZE']
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output return output
# Input tensor # Input tensor
x = torch.randn(size=(10,)).cuda() x = torch.randn(size=(10,)).cuda()
# Dropout mask # Dropout mask
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
# #
# Let's put it all together. # Let's put it all together.
@triton.jit @triton.jit
def _seeded_dropout( def _seeded_dropout(
x_ptr, x_ptr,

View File

@@ -4,15 +4,17 @@ Layer Normalization
""" """
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
# Forward Pass # Forward Pass
@triton.jit @triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] BLOCK_SIZE = META['BLOCK_SIZE']
# position of elements processed by this program # position of elements processed by this program
row = tl.program_id(0) row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE) cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N mask = cols < N
# offset data pointers to start at the row of interest # offset data pointers to start at the row of interest
@@ -24,9 +26,9 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
mean = tl.sum(x, axis=0) / N mean = tl.sum(x, axis=0) / N
# compute std # compute std
xmean = tl.where(mask, x - mean, 0.) xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N var = tl.sum(xmean * xmean, axis=0) / N
rstd = 1 / tl.sqrt(var + eps) rstd = 1 / tl.sqrt(var + eps)
xhat = xmean*rstd xhat = xmean * rstd
# write-back mean/rstd # write-back mean/rstd
tl.store(M + row, mean) tl.store(M + row, mean)
tl.store(V + row, rstd) tl.store(V + row, rstd)
@@ -41,16 +43,16 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
# Backward pass (DX + partial DW + partial DB) # Backward pass (DX + partial DW + partial DB)
@triton.jit @triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
stride, N, eps, stride, N, eps,
**META): **META):
GROUP_SIZE_M = META['GROUP_SIZE_M'] GROUP_SIZE_M = META['GROUP_SIZE_M']
BLOCK_SIZE_N = META['BLOCK_SIZE_N'] BLOCK_SIZE_N = META['BLOCK_SIZE_N']
# position of elements processed by this program # position of elements processed by this program
row = tl.program_id(0) row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N) cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N mask = cols < N
# offset data pointers to start at the row of interest # offset data pointers to start at the row of interest
X += row * stride X += row * stride
DY += row * stride DY += row * stride
DX += row * stride DX += row * stride
# offset locks and weight/bias gradient pointer # offset locks and weight/bias gradient pointer
@@ -59,28 +61,28 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
# these buffers stay in the L2, which allow this kernel # these buffers stay in the L2, which allow this kernel
# to be fast # to be fast
lock_id = row % GROUP_SIZE_M lock_id = row % GROUP_SIZE_M
Lock += lock_id Lock += lock_id
Count = Lock + GROUP_SIZE_M Count = Lock + GROUP_SIZE_M
DW = DW + lock_id*N + cols DW = DW + lock_id * N + cols
DB = DB + lock_id*N + cols DB = DB + lock_id * N + cols
# load data to SRAM # load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32) w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row) mean = tl.load(M + row)
rstd = tl.load(V + row) rstd = tl.load(V + row)
# compute dx # compute dx
xhat = (x - mean)*rstd xhat = (x - mean) * rstd
wdy = w * dy wdy = w * dy
xhat = tl.where(mask, xhat, 0.) xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy , 0.) wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat*mean1 + mean2))*rstd dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx # write-back dx
tl.store(DX + cols, dx, mask=mask) tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db # accumulate partial sums for dw/db
partial_dw = (dy*xhat).to(w.dtype) partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype) partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1: while tl.atomic_cas(Lock, 0, 1) == 1:
pass pass
@@ -97,24 +99,27 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
tl.atomic_xchg(Lock, 0) tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB) # Backward pass (total DW + total DB)
@triton.jit @triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
pid = tl.program_id(0) pid = tl.program_id(0)
BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M): for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
mask = (rows[:, None] < M) & (cols[None, :] < N) mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None]*N + cols[None, :] offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.) dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.) db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0) sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0) sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols<N) tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols<N) tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function): class LayerNorm(torch.autograd.Function):
@@ -129,7 +134,7 @@ class LayerNorm(torch.autograd.Function):
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda') rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel # Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size() MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE: if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps # heuristics for number of warps
@@ -140,8 +145,8 @@ class LayerNorm(torch.autograd.Function):
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps ctx.num_warps = num_warps
ctx.eps = eps ctx.eps = eps
return y return y
@staticmethod @staticmethod
@@ -154,11 +159,11 @@ class LayerNorm(torch.autograd.Function):
if N <= 4096: GROUP_SIZE_M = 128 if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256 if N <= 1024: GROUP_SIZE_M = 256
# allocate output # allocate output
locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda') locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy) dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics # enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB # also compute partial sums for DW and DB
@@ -172,8 +177,8 @@ class LayerNorm(torch.autograd.Function):
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel # accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M = 32, BLOCK_SIZE_M=32,
BLOCK_SIZE_N = 128) BLOCK_SIZE_N=128)
return dx, None, dw, db, None return dx, None, dw, db, None
@@ -184,10 +189,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data # create data
x_shape = (M, N) x_shape = (M, N)
w_shape = (x_shape[-1], ) w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x) dy = .1 * torch.randn_like(x)
x.requires_grad_(True) x.requires_grad_(True)
# forward pass # forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps) y_tri = layer_norm(x, w_shape, weight, bias, eps)
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1) triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1) triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['N'], x_names=['N'],
@@ -218,14 +224,14 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
) )
) )
def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'): def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data # create data
x_shape = (M, N) x_shape = (M, N)
w_shape = (x_shape[-1], ) w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x) dy = .1 * torch.randn_like(x)
x.requires_grad_(True) x.requires_grad_(True)
# utility functions # utility functions
if provider == 'triton': if provider == 'triton':
@@ -238,14 +244,15 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
y_fwd = lambda: apex_layer_norm(x) y_fwd = lambda: apex_layer_norm(x)
# forward pass # forward pass
if mode == 'forward': if mode == 'forward':
gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6 gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500) ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass # backward pass
if mode == 'backward': if mode == 'backward':
gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6 gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd() y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500) grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms) return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True) bench_layer_norm.run(save_path='.', print_data=True)