[STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` in place, and then clean up a few warts. This PR should be a no-op: it is all boring whitespace changes, and any config-file changes will land in a separate change to keep the review easy.
parent 120cda015e · commit 8bf551ae7a
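For reference, the ignored pycodestyle checks are E501 (line too long), E701 (multiple statements on one line after a colon), and E731 (assignment of a lambda expression), so autopep8 is limited to whitespace-level fixes here. The `.isort.cfg` itself is not part of this diff; a minimal sketch of a config compatible with the command above might look like the following. The concrete values are assumptions for illustration, not the settings actually used:
```
# Hypothetical .isort.cfg; illustrative only, the real config lands in a separate change.
[settings]
line_length = 120            # assumed: long enough that import sorting never trips E501
multi_line_output = 3        # assumed: vertical hanging indent for wrapped imports
known_first_party = triton   # assumed: keeps `import triton` in its own first-party section
```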
@@ -1,4 +1,5 @@
import torch

import triton

# -------------------------------

@@ -8,18 +9,18 @@ import triton
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block',
line_vals = [16, 32, 64, 128],
line_names = ['Block16', 'Block32', 'Block64', 'Block128'],
ylabel = 'TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\
for AT in [False] for BT in [False] \
for op_mode in ['dsd'] for layout_mode in ['dense']
x_names=['M', 'N', 'K'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64, 128],
line_names=['Block16', 'Block32', 'Block64', 'Block128'],
ylabel='TFLOPS',
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)
for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense']
]

@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout

@@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode]*1e-12
}[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)

@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block',
line_vals = [16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS',
plot_name = f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\
x_names=['M', 'N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64],
line_names=['Block16', 'Block32', 'Block64'],
ylabel='GBPS',
plot_name=f'{layout_mode}-square',
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)
for layout_mode in ['dense', 'tril']
]

@@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)

bench_matmul.run(print_data=True, show_plots=True)
bench_matmul.run(print_data=True, show_plots=True)
@@ -1,17 +1,18 @@
import torch

import triton

confs = [
triton.testing.Benchmark(
x_names = ['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg = 'provider',
line_vals = ['triton', 'torch'],
line_names = ['Triton', 'Torch'],
ylabel = 'GBPS',
plot_name = f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\
x_names=['N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
ylabel='GBPS',
plot_name=f'{mode}-2048',
args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)
for mode in ['forward', 'backward']
]

@@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
'triton': triton.ops.cross_entropy}[provider]
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':

@@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider):

if __name__ == '__main__':
bench_op.run(print_data=True)
bench_op.run(print_data=True)
@@ -1,6 +1,6 @@
import triton
import torch
import os

import triton

def rounded_linspace(low, high, steps, div):

@@ -29,16 +29,16 @@ square_confs = [
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128),
x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\
for i, x in enumerate(["N", "K"])\
for M in [2048]
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]
for i, x in enumerate(["N", "K"])
for M in [2048]
]

@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t()
if BT: b = b.t()
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas":

@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
except:
except Exception:
return None
return None
@@ -1,7 +1,8 @@
import argparse
import sys
import os
import inspect
import os
import sys

import triton
@@ -1,29 +1,28 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils
import glob
import tempfile
import shutil
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn
import urllib.request
import os
import platform
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion

from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext

def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
versions = ['-11.0', '-11', '-11-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return '', ''
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'

@@ -32,7 +31,7 @@ def get_llvm():
if not os.path.exists(llvm_library_dir):
try:
shutil.rmtree(os.path.join(dir, name))
except:
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')

@@ -96,7 +95,7 @@ class CMakeBuild(build_ext):
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
]
@@ -1,14 +1,18 @@
|
||||
from numpy import record
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from numpy import record
|
||||
|
||||
import triton
|
||||
|
||||
#######################
|
||||
# Utilities
|
||||
#######################
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
attrs = ','.join(attrs)
|
||||
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
|
||||
@@ -23,48 +27,51 @@ def nvsmi(attrs):
|
||||
#######################
|
||||
|
||||
matmul_data = {
|
||||
# square
|
||||
(256 , 256 , 256 ) : {'v100': 0.027},
|
||||
(512 , 512 , 512 ) : {'v100': 0.158},
|
||||
(1024, 1024, 1024 ) : {'v100': 0.466},
|
||||
(2048, 2048, 2048 ) : {'v100': 0.680},
|
||||
(4096, 4096, 4096 ) : {'v100': 0.831},
|
||||
(8192, 8192, 8192 ) : {'v100': 0.849},
|
||||
# tall-skinny
|
||||
(16 , 1024, 1024 ) : {'v100': 0.0128},
|
||||
(16 , 4096, 4096 ) : {'v100': 0.0883},
|
||||
(16 , 8192, 8192 ) : {'v100': 0.101},
|
||||
(64 , 1024, 1024 ) : {'v100': 0.073},
|
||||
(64 , 4096, 4096 ) : {'v100': 0.270},
|
||||
(64 , 8192, 8192 ) : {'v100': 0.360},
|
||||
(1024, 64 , 1024 ) : {'v100': 0.0692},
|
||||
(4096, 64 , 4096 ) : {'v100': 0.264},
|
||||
(8192, 64 , 8192 ) : {'v100': 0.323},
|
||||
# # deep reductions
|
||||
# (64 , 64 , 16384) : {'v100': 0.},
|
||||
# (64 , 64 , 65536) : {'v100': 0.},
|
||||
# (256 , 256 , 8192 ) : {'v100': 0.},
|
||||
# (256 , 256 , 32768) : {'v100': 0.},
|
||||
# square
|
||||
(256, 256, 256): {'v100': 0.027},
|
||||
(512, 512, 512): {'v100': 0.158},
|
||||
(1024, 1024, 1024): {'v100': 0.466},
|
||||
(2048, 2048, 2048): {'v100': 0.680},
|
||||
(4096, 4096, 4096): {'v100': 0.831},
|
||||
(8192, 8192, 8192): {'v100': 0.849},
|
||||
# tall-skinny
|
||||
(16, 1024, 1024): {'v100': 0.0128},
|
||||
(16, 4096, 4096): {'v100': 0.0883},
|
||||
(16, 8192, 8192): {'v100': 0.101},
|
||||
(64, 1024, 1024): {'v100': 0.073},
|
||||
(64, 4096, 4096): {'v100': 0.270},
|
||||
(64, 8192, 8192): {'v100': 0.360},
|
||||
(1024, 64, 1024): {'v100': 0.0692},
|
||||
(4096, 64, 4096): {'v100': 0.264},
|
||||
(8192, 64, 8192): {'v100': 0.323},
|
||||
# # deep reductions
|
||||
# (64 , 64 , 16384) : {'v100': 0.},
|
||||
# (64 , 64 , 65536) : {'v100': 0.},
|
||||
# (256 , 256 , 8192 ) : {'v100': 0.},
|
||||
# (256 , 256 , 32768) : {'v100': 0.},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M, N, K', matmul_data.keys())
|
||||
def test_matmul(M, N, K):
|
||||
ref_gpu_util = matmul_data[(M, N, K)]['v100']
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
ref_sm_clock = 1350
|
||||
max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
|
||||
max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
|
||||
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
|
||||
b = torch.randn((K, N), dtype=torch.float16, device='cuda')
|
||||
fn = lambda: triton.ops.matmul(a, b)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
|
||||
cur_gpu_perf = 2.*M*N*K/ms * 1e-9
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
|
||||
|
||||
|
||||
#######################
|
||||
# Element-Wise
|
||||
#######################
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
@@ -80,21 +87,22 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
|
||||
|
||||
elementwise_data = {
|
||||
1024*16 : {'v100': 0.0219},
|
||||
1024*64 : {'v100': 0.0791},
|
||||
1024*256 : {'v100': 0.243},
|
||||
1024*1024 : {'v100': 0.534},
|
||||
1024*4096 : {'v100': 0.796},
|
||||
1024*16384: {'v100': 0.905},
|
||||
1024*65536: {'v100': 0.939},
|
||||
1024 * 16: {'v100': 0.0219},
|
||||
1024 * 64: {'v100': 0.0791},
|
||||
1024 * 256: {'v100': 0.243},
|
||||
1024 * 1024: {'v100': 0.534},
|
||||
1024 * 4096: {'v100': 0.796},
|
||||
1024 * 16384: {'v100': 0.905},
|
||||
1024 * 65536: {'v100': 0.939},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('N', elementwise_data.keys())
|
||||
def test_elementwise(N):
|
||||
ref_gpu_util = elementwise_data[N]['v100']
|
||||
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
|
||||
ref_mem_clock = 877
|
||||
max_gpu_perf = 512*2*ref_mem_clock*1e-3
|
||||
max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
|
||||
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
|
||||
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
|
||||
x = torch.randn_like(z)
|
||||
@@ -102,7 +110,6 @@ def test_elementwise(N):
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
|
||||
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
|
||||
|
||||
|
@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
|
||||
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, SIZE: tl.constexpr):
|
||||
pass
|
||||
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
|
||||
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@triton.jit
|
||||
def kernel(Z, X, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
|
||||
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
|
||||
@triton.jit
|
||||
def kernel(Z, X, Y, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
@@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
|
||||
# ---------------
|
||||
# test binary ops
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
|
||||
(dtype_x, dtype_y, op)
|
||||
for op in ['+', '-', '*', '/', '%']
|
||||
for dtype_x in dtypes
|
||||
for dtype_y in dtypes
|
||||
for op in ['+', '-', '*', '/', '%']
|
||||
for dtype_x in dtypes
|
||||
for dtype_y in dtypes
|
||||
])
|
||||
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
expr = f' x {op} y'
|
||||
@@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y",
|
||||
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
|
||||
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
|
||||
)
|
||||
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
|
||||
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
|
||||
)
|
||||
def test_floordiv(dtype_x, dtype_y, device='cuda'):
|
||||
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
|
||||
# through to //, so we have to use a nonstandard expression to get a
|
||||
@@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
# test compare ops
|
||||
# ---------------
|
||||
ops = ['==', '!=', '>', '<', '>=', '<=']
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
|
||||
# real
|
||||
[
|
||||
(dtype_x, dtype_y, op, 'real', 'real') \
|
||||
for op in ops \
|
||||
for dtype_x in dtypes \
|
||||
for dtype_y in dtypes
|
||||
] + \
|
||||
# NaNs
|
||||
[('float32', 'float32', op, mode_x, mode_y) \
|
||||
for op in ops
|
||||
for mode_x, mode_y in [('nan' , 'real'),
|
||||
('real', 'nan'),
|
||||
('nan' , 'nan')]
|
||||
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
|
||||
# real
|
||||
[
|
||||
(dtype_x, dtype_y, op, 'real', 'real')
|
||||
for op in ops
|
||||
for dtype_x in dtypes
|
||||
for dtype_y in dtypes
|
||||
] +
|
||||
# NaNs
|
||||
[('float32', 'float32', op, mode_x, mode_y)
|
||||
for op in ops
|
||||
for mode_x, mode_y in [('nan', 'real'),
|
||||
('real', 'nan'),
|
||||
('nan', 'nan')]
|
||||
|
||||
])
|
||||
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
|
||||
expr = f'x {op} y'
|
||||
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
|
||||
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
# 'exp', 'log', 'cos', 'sin'
|
||||
# ])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expr", [
|
||||
'exp', 'log', 'cos', 'sin'
|
||||
])
|
||||
@@ -368,8 +376,8 @@ def make_ptr_str(name, shape):
|
||||
|
||||
@pytest.mark.parametrize("expr, dtype_str", [
|
||||
(f'x[{s}]', d)
|
||||
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
||||
for d in ['int32', 'uint32', 'uint16']
|
||||
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
||||
for d in ['int32', 'uint32', 'uint16']
|
||||
])
|
||||
def test_index1d(expr, dtype_str, device='cuda'):
|
||||
rank_x = expr.count(':')
|
||||
@@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
@triton.jit
|
||||
def fn(a, b):
|
||||
return a + b, \
|
||||
a - b, \
|
||||
a * b
|
||||
a - b, \
|
||||
a * b
|
||||
|
||||
|
||||
def test_tuples():
|
||||
@@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||
(dtype_x, dtype_z, False)
|
||||
for dtype_x in dtypes
|
||||
for dtype_z in dtypes
|
||||
for dtype_x in dtypes
|
||||
for dtype_z in dtypes
|
||||
] + [
|
||||
('float32', 'bfloat16', False),
|
||||
('bfloat16', 'float32', False),
|
||||
@@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, Z, BITCAST: tl.constexpr):
|
||||
x = tl.load(X)
|
||||
z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
|
||||
tl.store(Z, z)
|
||||
|
||||
# triton result
|
||||
@@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
# ---------------
|
||||
# test reduce
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape",
|
||||
[(dtype, shape) \
|
||||
for dtype in dtypes\
|
||||
for shape in [128, 512]])
|
||||
[(dtype, shape)
|
||||
for dtype in dtypes
|
||||
for shape in [128, 512]])
|
||||
def test_reduce1d(dtype_str, shape, device='cuda'):
|
||||
|
||||
# triton kernel
|
||||
@@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
|
||||
range_m = tl.arange(0, BLOCK_M)
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = tl.sum(x, axis=AXIS)
|
||||
tl.store(Z + range_m, z)
|
||||
# input
|
||||
@@ -608,11 +618,13 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
|
||||
# ---------------
|
||||
# test permute
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape, perm",
|
||||
[(dtype, shape, perm) \
|
||||
for dtype in ['float32']\
|
||||
for shape in [(128, 128)]\
|
||||
for perm in [(1, 0)]])
|
||||
[(dtype, shape, perm)
|
||||
for dtype in ['float32']
|
||||
for shape in [(128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
|
||||
# triton kernel
|
||||
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# test dot
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, device='cuda'):
|
||||
# triton kernel
|
||||
@@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'):
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX = epilogue=='add-matrix',
|
||||
ADD_ROWS = epilogue=='add-rows',
|
||||
ADD_COLS = epilogue=='add-cols')
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols')
|
||||
# torch result
|
||||
z_ref = np.matmul(x, y)
|
||||
if epilogue == 'add-matrix':
|
||||
z_ref += z
|
||||
if epilogue == 'add-rows':
|
||||
z_ref += z[:,0][:, None]
|
||||
z_ref += z[:, 0][:, None]
|
||||
if epilogue == 'add-cols':
|
||||
z_ref += z[0,:][None, :]
|
||||
z_ref += z[0, :][None, :]
|
||||
# compare
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
|
||||
def test_dot_without_load():
|
||||
@triton.jit
|
||||
def kernel(out):
|
||||
@@ -713,28 +727,30 @@ def test_dot_without_load():
|
||||
b = tl.zeros((32, 32), tl.float32)
|
||||
c = tl.zeros((32, 32), tl.float32)
|
||||
c = tl.dot(a, b)
|
||||
pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
|
||||
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
tl.store(pout, c)
|
||||
|
||||
out = torch.ones((32,32), dtype=torch.float32, device="cuda")
|
||||
out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
|
||||
kernel[(1,)](out)
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("start", [0, 1, 7, 16])
|
||||
def test_arange(start, device='cuda'):
|
||||
BLOCK = 128
|
||||
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
|
||||
|
||||
@triton.jit
|
||||
def _kernel(z, BLOCK: tl.constexpr,
|
||||
START: tl.constexpr, END: tl.constexpr):
|
||||
off = tl.arange(0, BLOCK)
|
||||
val = tl.arange(START, END)
|
||||
tl.store(z + off, val)
|
||||
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
|
||||
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
|
||||
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
|
||||
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
|
||||
# ---------------
|
||||
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
|
||||
# ---------------
|
||||
# 'bfloat16': torch.bfloat16,
|
||||
# Testing masked loads with an intermate copy to shared memory run.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
M = 32
|
||||
@@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
N_offsets = tl.arange(0, N)
|
||||
K_offsets = tl.arange(0, K)
|
||||
|
||||
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
|
||||
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
|
||||
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
|
||||
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
|
||||
|
||||
# Load inputs.
|
||||
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
|
||||
@@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
o = tl.dot(x, w)
|
||||
|
||||
# Store output
|
||||
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
|
||||
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
|
||||
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
|
||||
|
||||
pgm = _kernel[(1,)](in1, in2, out,
|
||||
in1.stride()[0],
|
||||
in2.stride()[0],
|
||||
out.stride()[0],
|
||||
in1.numel(),
|
||||
in2.numel(),
|
||||
out.numel(),
|
||||
M=M, N=N, K=K)
|
||||
in1.stride()[0],
|
||||
in2.stride()[0],
|
||||
out.stride()[0],
|
||||
in1.numel(),
|
||||
in2.numel(),
|
||||
out.numel(),
|
||||
M=M, N=N, K=K)
|
||||
|
||||
reference_out =torch.matmul(in1, in2)
|
||||
reference_out = torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
def test_load_cache_modifier(cache):
|
||||
src = torch.empty(128, device='cuda')
|
||||
@@ -796,8 +815,8 @@ def test_load_cache_modifier(cache):
|
||||
@triton.jit
|
||||
def _kernel(dst, src, CACHE: tl.constexpr):
|
||||
offsets = tl.arange(0, 128)
|
||||
x = tl.load(src+offsets, cache_modifier=CACHE)
|
||||
tl.store(dst+offsets, x)
|
||||
x = tl.load(src + offsets, cache_modifier=CACHE)
|
||||
tl.store(dst + offsets, x)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
@@ -830,11 +849,14 @@ def test_load_cache_modifier(cache):
|
||||
# ---------------
|
||||
# test default
|
||||
# ---------------
|
||||
#TODO: can't be local to test_default
|
||||
# TODO: can't be local to test_default
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _impl(value = 10):
|
||||
def _impl(value=10):
|
||||
return value
|
||||
|
||||
|
||||
def test_default():
|
||||
value = 5
|
||||
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
@@ -851,7 +873,9 @@ def test_default():
|
||||
|
||||
# ---------------
|
||||
# test noop
|
||||
#----------------
|
||||
# ----------------
|
||||
|
||||
|
||||
def test_noop(device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(x):
|
||||
@@ -861,9 +885,9 @@ def test_noop(device='cuda'):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
|
||||
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
|
||||
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
|
@@ -1,16 +1,17 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.stats
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from numpy.random import Philox
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
#####################################
|
||||
## Reference Philox Implementation
|
||||
# Reference Philox Implementation
|
||||
#####################################
|
||||
|
||||
|
||||
class PhiloxConfig:
|
||||
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
|
||||
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
|
||||
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
|
||||
|
||||
|
||||
#####################################
|
||||
## Unit Tests
|
||||
# Unit Tests
|
||||
#####################################
|
||||
|
||||
BLOCK = 1024
|
||||
|
||||
# test generation of random uint32
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in ['10', '4,53', '10000']\
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
[(size, seed) for size in ['10', '4,53', '10000']
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
def test_randint(size, seed, device='cuda'):
|
||||
size = list(map(int, size.split(',')))
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
@@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'):
|
||||
assert out_tri == out_ref
|
||||
|
||||
# test uniform PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]\
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
def test_rand(size, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
@@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'):
|
||||
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
|
||||
|
||||
# test normal PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]\
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
def test_randn(size, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import triton
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
|
||||
# torch result
|
||||
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
|
||||
# broadcast at_mask to the same shape as rx
|
||||
if is_causal: at_mask = torch.tril(at_mask)
|
||||
if is_causal:
|
||||
at_mask = torch.tril(at_mask)
|
||||
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
|
||||
rx[M == 0] = float("-inf")
|
||||
# rx += kp_mask[:, None, None, :]
|
||||
|
@@ -1,14 +1,16 @@
|
||||
import torch
|
||||
import triton
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']\
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
)
|
||||
def test_op(M, N, dtype, mode):
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
|
||||
@@ -30,4 +32,4 @@ def test_op(M, N, dtype, mode):
|
||||
x.grad.zero_()
|
||||
th_y.backward(dy)
|
||||
th_dx = x.grad.clone()
|
||||
triton.testing.assert_almost_equal(th_dx, tt_dx)
|
||||
triton.testing.assert_almost_equal(th_dx, tt_dx)
|
||||
|
@@ -1,8 +1,10 @@
|
||||
import pytest
|
||||
import itertools
|
||||
import triton
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
|
||||
@@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
K = BLOCK_K * SPLIT_K if K is None else K
|
||||
# allocate/transpose inputs
|
||||
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
|
||||
a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
|
||||
b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
|
||||
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
|
||||
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
|
||||
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
|
||||
triton.testing.assert_almost_equal(th_c, tt_c)
|
||||
|
@@ -1,13 +1,16 @@
|
||||
import torch
|
||||
import triton
|
||||
from triton.code_gen import JITFunction
|
||||
import triton.language as tl
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import JITFunction
|
||||
|
||||
tmpdir = ".tmp"
|
||||
|
||||
|
||||
@triton.jit
|
||||
def function_1(i):
|
||||
i = i + 1
|
||||
@@ -20,18 +23,21 @@ def function_2(i):
|
||||
i = i + 1
|
||||
return i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, i, BLOCK: tl.constexpr):
|
||||
i = i + 1
|
||||
i = function_1(i)
|
||||
tl.store(X, i)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["i"])
|
||||
def kernel_nospec(X, i, BLOCK: tl.constexpr):
|
||||
i = i + 1
|
||||
i = function_1(i)
|
||||
tl.store(X, i)
|
||||
|
||||
|
||||
def apply_src_change(target, old, new):
|
||||
delattr(kernel.fn, 'hash')
|
||||
delattr(function_1.fn, 'hash')
|
||||
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
|
||||
target.src = target.src.replace(new, old)
|
||||
return ret
|
||||
|
||||
|
||||
def test_nochange():
|
||||
baseline = kernel.cache_key
|
||||
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
|
||||
assert baseline == updated
|
||||
|
||||
|
||||
def test_toplevel_change():
|
||||
baseline = kernel.cache_key
|
||||
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
|
||||
assert baseline != updated
|
||||
|
||||
|
||||
def test_nested1_change():
|
||||
baseline = kernel.cache_key
|
||||
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
|
||||
assert baseline != updated
|
||||
|
||||
|
||||
def reset_tmp_dir():
|
||||
os.environ["TRITON_CACHE_DIR"] = tmpdir
|
||||
if os.path.exists(tmpdir):
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
def test_reuse():
|
||||
counter = 0
|
||||
|
||||
def inc_counter(key, binary, repr):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
@@ -73,11 +85,12 @@ def test_reuse():
|
||||
for i in range(10):
|
||||
kernel[(1,)](x, 1, BLOCK=1024)
|
||||
assert counter == 1
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
||||
def test_specialize(mode):
|
||||
counter = 0
|
||||
|
||||
def inc_counter(key, binary, repr):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
|
@@ -1,9 +1,11 @@
|
||||
import torch
|
||||
import triton
|
||||
import pytest
|
||||
import subprocess
|
||||
import triton.language as tl
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def get_p2p_matrix():
|
||||
|
@@ -1,26 +1,26 @@
|
||||
import ast
|
||||
import builtins
|
||||
import dbm
|
||||
import functools
|
||||
import inspect
|
||||
import struct
|
||||
import sys
|
||||
import textwrap
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import struct
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from .tools.disasm import extract
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from filelock import FileLock
|
||||
import dbm
|
||||
import tempfile
|
||||
from typing import Optional, Dict
|
||||
import time
|
||||
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
@@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
arg_node = node.args.args[-i-1]
|
||||
arg_node = node.args.args[-i - 1]
|
||||
annotation = arg_node.annotation
|
||||
name = arg_node.arg
|
||||
st_target = ast.Name(id=name, ctx=ast.Store())
|
||||
@@ -134,8 +134,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(fn.args[idx])
|
||||
idx += 1
|
||||
|
||||
|
||||
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
if inline:
|
||||
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# default: call visit_Assign
|
||||
return self.visit_Assign(node)
|
||||
|
||||
|
||||
def visit_Assign(self, node):
|
||||
_names = []
|
||||
for target in node.targets:
|
||||
@@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if else_bb:
|
||||
self.builder.set_insert_block(else_bb)
|
||||
is_terminator = self.visit_compound_statement(node.orelse)
|
||||
#TODO: last statement is a terminator?
|
||||
# TODO: last statement is a terminator?
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
self.module.seal_block(endif_bb)
|
||||
@@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
||||
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
||||
self.visit(pos_cond_node),\
|
||||
self.visit(neg_cond_node),\
|
||||
_builder=self.builder)
|
||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),
|
||||
self.visit(pos_cond_node),
|
||||
self.visit(neg_cond_node),
|
||||
_builder=self.builder)
|
||||
#cond_node = neg_cond_node
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
@@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if isinstance(fn, JITFunction):
|
||||
return fn(*args, generator=self, **kws)
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
return fn(*args, **kws)
|
||||
|
||||
@@ -505,10 +503,10 @@ class Binary:
|
||||
|
||||
class LoadedBinary:
|
||||
def __init__(self, device: int, bin: Binary):
|
||||
module, kernel = _triton.code_gen.load_binary(bin.backend,
|
||||
bin.name,
|
||||
bin.asm,
|
||||
bin.shared_mem,
|
||||
module, kernel = _triton.code_gen.load_binary(bin.backend,
|
||||
bin.name,
|
||||
bin.asm,
|
||||
bin.shared_mem,
|
||||
device)
|
||||
self.bin = bin
|
||||
self.asm = bin.asm
|
||||
@@ -520,8 +518,8 @@ class LoadedBinary:
|
||||
|
||||
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
|
||||
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
|
||||
grid_0, grid_1, grid_2,
|
||||
self.bin.num_warps * 32, 1, 1,
|
||||
grid_0, grid_1, grid_2,
|
||||
self.bin.num_warps * 32, 1, 1,
|
||||
args, self.bin.shared_mem)
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
@@ -632,10 +630,14 @@ class Kernel:
|
||||
|
||||
@staticmethod
|
||||
def pow2_divisor(N):
|
||||
if N % 16 == 0: return 16
|
||||
if N % 8 == 0: return 8
|
||||
if N % 4 == 0: return 4
|
||||
if N % 2 == 0: return 2
|
||||
if N % 16 == 0:
|
||||
return 16
|
||||
if N % 8 == 0:
|
||||
return 8
|
||||
if N % 4 == 0:
|
||||
return 4
|
||||
if N % 2 == 0:
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def __init__(self, fn):
|
||||
@@ -675,7 +677,7 @@ class Kernel:
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
# attributes
|
||||
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
|
||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
|
||||
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
|
||||
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
||||
|
||||
# transforms ints whose value is one into constants for just-in-time compilation
|
||||
@@ -705,7 +707,7 @@ class Kernel:
|
||||
if binary is None:
|
||||
binary = self._compile(
|
||||
*wargs, device=device_idx, attributes=attributes,
|
||||
num_warps=num_warps, num_stages=num_stages,
|
||||
num_warps=num_warps, num_stages=num_stages,
|
||||
constants=constants,
|
||||
)
|
||||
if bin_cache_path:
|
||||
@@ -766,13 +768,12 @@ class Launcher:
|
||||
|
||||
def __call__(self, *wargs, **kwargs):
|
||||
return self.kernel(*wargs, **kwargs, grid=self.grid)
|
||||
|
||||
|
||||
|
||||
class Autotuner:
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
|
||||
@@ -788,6 +789,7 @@ class Autotuner:
|
||||
self.hook = lambda args: 0
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
|
||||
def _hook(args):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
@@ -802,7 +804,7 @@ class Autotuner:
|
||||
perf_model, top_k, prune_num_stages_by = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.prune_num_stages_by = prune_num_stages_by
|
||||
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
@@ -814,6 +816,7 @@ class Autotuner:
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
@@ -836,9 +839,9 @@ class Autotuner:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) \
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
@@ -876,7 +879,7 @@ def version_key():
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
#########################3
|
||||
# 3
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
|
||||
@@ -888,7 +891,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
|
||||
def visit_Name(self, node):
|
||||
return self.globals.get(node.id, None)
|
||||
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
@@ -917,10 +920,10 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
self.ret = (self.ret + func.hash).encode("utf-8")
|
||||
self.ret = hashlib.md5(self.ret).hexdigest()
|
||||
|
||||
class JITFunction:
|
||||
|
||||
cache_hook = None
|
||||
|
||||
class JITFunction:
|
||||
|
||||
cache_hook = None
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
# information of wrapped function
|
||||
@@ -946,7 +949,6 @@ class JITFunction:
|
||||
# forward docs
|
||||
self.__doc__ = fn.__doc__
|
||||
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def cache_key(self):
|
||||
@@ -1027,6 +1029,7 @@ class Config:
|
||||
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
|
||||
function are args.
|
||||
"""
|
||||
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
|
||||
self.kwargs = kwargs
|
||||
self.num_warps = num_warps
|
||||
@@ -1049,19 +1052,19 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
|
||||
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple time.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
reset the value of the provided tensor to `zero` before running any configuration.
|
||||
|
||||
@@ -1069,7 +1072,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
|
||||
@@ -1099,7 +1102,7 @@ def heuristics(values):
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
|
||||
|
||||
|
||||
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a list of positional arguments as input.
|
||||
.type values: dict[str, Callable[[list[Any]], Any]]
|
||||
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
|
||||
def cdiv(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
"""Return the smallest power of 2 greater than or equal to n"""
|
||||
n -= 1
|
||||
@@ -1163,13 +1167,14 @@ def next_power_of_2(n):
|
||||
|
||||
######
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
self.base = base
|
||||
self.is_cuda = base.is_cuda
|
||||
self.device = base.device
|
||||
|
||||
|
||||
def data_ptr(self):
|
||||
return self.base.data_ptr()
|
||||
|
||||
|
@@ -1,8 +1,8 @@
|
||||
import triton
|
||||
from triton._C.libtriton.triton import ir
|
||||
from triton._C.libtriton.triton import frontend
|
||||
from functools import wraps
|
||||
|
||||
import triton
|
||||
from triton._C.libtriton.triton import frontend, ir
|
||||
|
||||
|
||||
# convert block/dtype to ir values
|
||||
def _to_ir(x, builder):
|
||||
@@ -65,7 +65,7 @@ def builtin(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -111,6 +111,7 @@ class pointer_dtype:
|
||||
def __str__(self):
|
||||
return f'pointer<{self.element_ty}>'
|
||||
|
||||
|
||||
# scalar types
|
||||
int1 = dtype(ir.type.get_int1)
|
||||
int8 = dtype(ir.type.get_int8)
|
||||
@@ -331,27 +332,27 @@ class constexpr:
|
||||
|
||||
def __rsub__(self, other):
|
||||
return other.value - self.value
|
||||
|
||||
|
||||
def __mul__(self, other):
|
||||
return self.value * other.value
|
||||
|
||||
def __rmul__(self, other):
|
||||
return other.value * self.value
|
||||
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.value / other.value
|
||||
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return other.value / self.value
|
||||
|
||||
|
||||
def __floordiv__(self, other):
|
||||
return self.value // other.value
|
||||
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return other.value // self.value
|
||||
|
||||
#
|
||||
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.value > other.value
|
||||
|
||||
@@ -360,25 +361,25 @@ class constexpr:
|
||||
|
||||
def __ge__(self, other):
|
||||
return self.value >= other.value
|
||||
|
||||
|
||||
def __rge__(self, other):
|
||||
return other.value >= self.value
|
||||
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
def __rlt__(self, other):
|
||||
return other.value < self.value
|
||||
|
||||
|
||||
def __le__(self, other):
|
||||
return self.value <= other.value
|
||||
|
||||
|
||||
def __rle__(self, other):
|
||||
return other.value <= self.value
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.value != other.value
|
||||
|
||||
@@ -489,15 +490,16 @@ def broadcast_to(input, shape, _builder=None):
|
||||
"""
|
||||
return frontend.broadcast_to(input, shape, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def cat(input, other, _builder=None):
|
||||
"""
|
||||
Concatenate the given blocks
|
||||
|
||||
:param input: The first input block.
|
||||
:type input:
|
||||
:type input:
|
||||
:param other: The second input block.
|
||||
:type other:
|
||||
:type other:
|
||||
"""
|
||||
return frontend.cat(input, other, _builder)
|
||||
|
||||
@@ -508,7 +510,7 @@ def reshape(input, shape, _builder=None):
|
||||
Tries to reshape the given block to a new shape.
|
||||
|
||||
:param input: The input block.
|
||||
:type input:
|
||||
:type input:
|
||||
:param shape: The desired shape.
|
||||
:type shape: Tuple[int]
|
||||
|
||||
@@ -546,7 +548,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
|
||||
"""
|
||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||
|
||||
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
|
||||
:code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
|
||||
|
||||
:code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`.
|
||||
|
||||
@@ -565,7 +567,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui
|
||||
@builtin
|
||||
def store(pointer, value, mask=None, _builder=None):
|
||||
"""
|
||||
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||
Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||
|
||||
:code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
|
||||
|
||||
@@ -600,9 +602,10 @@ def _add_atomic_docstr(name):
|
||||
"""
|
||||
func.__doc__ = docstr.format(name=name)
|
||||
return func
|
||||
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap")
|
||||
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
|
||||
def atomic_xchg(pointer, val, mask=None, _builder=None):
|
||||
return frontend.atomic_xchg(pointer, val, mask, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("add")
|
||||
def atomic_add(pointer, val, mask=None, _builder=None):
|
||||
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
|
||||
def umulhi(x, y, _builder=None):
|
||||
return frontend.umulhi(x, y, _builder)
|
||||
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
|
||||
def _decorator(func):
|
||||
@@ -694,24 +699,28 @@ def _add_math_1arg_docstr(name):
|
||||
"""
|
||||
func.__doc__ = docstr.format(name=name)
|
||||
return func
|
||||
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("exponential")
|
||||
def exp(x, _builder=None):
|
||||
return frontend.exp(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("natural logarithm")
|
||||
def log(x, _builder=None):
|
||||
return frontend.log(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("cosine")
|
||||
def cos(x, _builder=None):
|
||||
return frontend.cos(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("sine")
|
||||
def sin(x, _builder=None):
|
||||
@@ -739,9 +748,10 @@ def _add_reduction_docstr(name):
|
||||
"""
|
||||
func.__doc__ = docstr.format(name=name)
|
||||
return func
|
||||
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("maximum")
|
||||
def max(input, axis, _builder=None):
|
||||
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
|
||||
def sum(input, axis, _builder=None):
|
||||
return frontend.sum(input, axis, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("xor sum")
|
||||
def xor_sum(input, axis, _builder=None):
|
||||
@@ -778,7 +789,7 @@ def debug_barrier(_builder=None):
@builtin
def multiple_of(input, value, _builder=None):
    """
    Lets the compiler know that the values in :code:`input` are all multiples of :code:`value`.
    """
    return frontend.multiple_of(input, value, _builder)

@@ -786,7 +797,7 @@ def multiple_of(input, value, _builder=None):
@builtin
def max_contiguous(input, value, _builder=None):
    """
    Lets the compiler know that the first `value` values in :code:`input` are contiguous.
    """
    return frontend.max_contiguous(input, value, _builder)
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
|
||||
def abs(x):
|
||||
return where(x >= 0, x, -x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv(x, div):
|
||||
"""
@@ -871,13 +883,14 @@ def ravel(x):
    """
    return triton.language.reshape(x, [x.type.numel])


@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
    """
    transforms indices of a row-major size_i*size_j matrix into those
    of one where indices are row major for each group of size_j rows.
    For example, for size_i = size_j = 4 and size_g = 2, it will transform
    [[0 , 1 , 2 , 3 ],
     [4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11],
     [12, 13, 14, 15]]
@@ -888,16 +901,16 @@ def swizzle2d(i, j, size_i, size_j, size_g):
    [9, 11, 13, 15]]
    """
    # "unrolled index in array"
    ij = i*size_j + j
    ij = i * size_j + j
    # number of elements in `size_g` groups
    # of `size_j` columns
    size_gj = size_g * size_j
    # index of the group in which (i,j) is
    group_id = ij // size_gj
    # row-index of the first element of this group
    off_i = group_id * size_g
    # last group may have fewer rows
    size_g = minimum(size_i - off_i, size_g)
    # new row and column indices
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
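To make the mapping in the docstring above easy to check, here is a small host-side sketch (plain Python; the helper name `swizzle2d_ref` is made up for illustration and is not part of this diff) that applies the same formulas and reproduces the 4x4, `size_g = 2` example:

```
# reference implementation of the swizzle2d index transformation
def swizzle2d_ref(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                   # unrolled row-major index
    size_gj = size_g * size_j             # elements per group of `size_g` rows
    group_id = ij // size_gj
    off_i = group_id * size_g             # first row of this group
    size_g = min(size_i - off_i, size_g)  # last group may be smaller
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j

out = [[None] * 4 for _ in range(4)]
for i in range(4):
    for j in range(4):
        ni, nj = swizzle2d_ref(i, j, 4, 4, 2)
        out[ni][nj] = i * 4 + j
print(out)  # [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
```

This is the same re-ordering the blocksparse `_dsd_kernel` later in this diff applies to its program ids via `tl.swizzle2d`.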
@@ -1,7 +1,6 @@
import triton
from . import core as tl


PHILOX_KEY_A: tl.constexpr = -1640531527  # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019  # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501  # 0xD2511F53

@@ -1,4 +1,4 @@
#from .conv import _conv, conv
from .matmul import _matmul, matmul
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse
from .matmul import _matmul, matmul

@@ -1,2 +1,2 @@
from .matmul import matmul
from .softmax import softmax
from .softmax import softmax
@@ -1,6 +1,7 @@
import torch

import triton
import triton.language as tl
import torch

# ********************************************************
# --------------------------------------------------------
@@ -11,16 +12,17 @@ import torch
|
||||
# --------------------------------------------------------
|
||||
# ********************************************************
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
|
||||
):
|
||||
@@ -30,25 +32,25 @@ def _sdd_kernel(
|
||||
block_id = tl.program_id(1) + grid_offset
|
||||
lut += block_id * 3
|
||||
# offsets
|
||||
off_z = tl.program_id(2) # batch
|
||||
off_h = tl.load(lut + 0) # head
|
||||
|
||||
off_z = tl.program_id(2) # batch
|
||||
off_h = tl.load(lut + 0) # head
|
||||
|
||||
# initialize pointers to A
|
||||
start_am = tl.load(lut + 1)
|
||||
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
|
||||
offs_ak = tl.arange(0, TILE_K)
|
||||
a_ptrs = A + (off_z * stride_za \
|
||||
+ off_h * stride_ha \
|
||||
+ offs_am[:, None] * stride_ma \
|
||||
+ offs_ak[None, :] * stride_ak)
|
||||
a_ptrs = A + (off_z * stride_za
|
||||
+ off_h * stride_ha
|
||||
+ offs_am[:, None] * stride_ma
|
||||
+ offs_ak[None, :] * stride_ak)
|
||||
# initialize pointers to B
|
||||
start_bn = tl.load(lut + 2)
|
||||
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
|
||||
offs_bk = tl.arange(0, TILE_K)
|
||||
b_ptrs = B + (off_z * stride_zb \
|
||||
+ off_h * stride_hb \
|
||||
+ offs_bn[None, :] * stride_nb \
|
||||
+ offs_bk[:, None] * stride_bk)
|
||||
b_ptrs = B + (off_z * stride_zb
|
||||
+ off_h * stride_hb
|
||||
+ offs_bn[None, :] * stride_nb
|
||||
+ offs_bk[:, None] * stride_bk)
|
||||
## ---------------- ##
|
||||
## Inner Loop ##
|
||||
## ---------------- ##
|
||||
@@ -69,13 +71,14 @@ def _sdd_kernel(
|
||||
## ---------------- ##
|
||||
offs_cm = tl.arange(0, TILE_M) % BLOCK
|
||||
offs_cn = tl.arange(0, TILE_N) % BLOCK
|
||||
pc = C + (off_z * stride_zc \
|
||||
+ block_id * stride_hc \
|
||||
+ offs_cm[:, None] * stride_mc \
|
||||
+ offs_cn[None, :] * stride_nc)
|
||||
pc = C + (off_z * stride_zc
|
||||
+ block_id * stride_hc
|
||||
+ offs_cm[:, None] * stride_mc
|
||||
+ offs_cn[None, :] * stride_nc)
|
||||
tl.store(pc, c, mask=True)
|
||||
|
||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
|
||||
|
||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
@@ -103,7 +106,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
Ka, 0, lut,
|
||||
TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4,
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
|
||||
num_warps=4,
|
||||
)
|
||||
return c
|
||||
@@ -119,50 +122,52 @@ def sdd_lut(layout, block, device):
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
# in order to minimize computations in the inner loop of the matmul kernel.
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
|
||||
):
|
||||
#------------#
|
||||
#- Prologue -#
|
||||
#------------#
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
num_pid_m = tl.num_programs(0)
|
||||
num_pid_n = tl.num_programs(1)
|
||||
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
|
||||
pidz = tl.program_id(2)
|
||||
pidz = tl.program_id(2)
|
||||
header = lut + pid_n * 4
|
||||
offset = tl.load(header + 0)
|
||||
K = tl.load(header + 1)
|
||||
K = tl.load(header + 1)
|
||||
column = tl.load(header + 2)
|
||||
off_h = tl.load(header + 3)
|
||||
pinc = lut + offset
|
||||
off_h = tl.load(header + 3)
|
||||
pinc = lut + offset
|
||||
# initialize pointers to A (sparse)
|
||||
block_id = tl.load(pinc + 1)
|
||||
block_id = tl.multiple_of(block_id, 8) # compiler hint
|
||||
block_id = tl.load(pinc + 1)
|
||||
block_id = tl.multiple_of(block_id, 8) # compiler hint
|
||||
offs_am = tl.arange(0, TILE_M)
|
||||
offs_ak = tl.arange(0, TILE_K)
|
||||
pa = A + pidz * stride_az \
|
||||
+ block_id * stride_ha \
|
||||
+ offs_am[:, None] * stride_am \
|
||||
+ offs_ak[None, :] * stride_ak
|
||||
+ block_id * stride_ha \
|
||||
+ offs_am[:, None] * stride_am \
|
||||
+ offs_ak[None, :] * stride_ak
|
||||
# initialize pointers to B (dense)
|
||||
offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N)
|
||||
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
|
||||
start_bk = tl.load(pinc)
|
||||
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
|
||||
offs_bk = start_bk + tl.arange(0, TILE_K)
|
||||
offs_bk = start_bk + tl.arange(0, TILE_K)
|
||||
pb = B + pidz * stride_zb \
|
||||
+ off_h * stride_hb \
|
||||
+ offs_bn[None, :] * stride_bn \
|
||||
+ offs_bk[:, None] * stride_bk
|
||||
+ off_h * stride_hb \
|
||||
+ offs_bn[None, :] * stride_bn \
|
||||
+ offs_bk[:, None] * stride_bk
|
||||
## ---------------- ##
|
||||
## Inner Loop ##
|
||||
## ---------------- ##
|
||||
@@ -177,7 +182,7 @@ def _dsd_kernel(
|
||||
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
|
||||
acc += tl.dot(a, b)
|
||||
pa += inc_a
|
||||
pb += inc_b*stride_bk
|
||||
pb += inc_b * stride_bk
|
||||
pinc += 2
|
||||
inc_a = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
@@ -185,15 +190,16 @@ def _dsd_kernel(
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
c = acc.to(C.dtype.element_ty)
|
||||
# initialize pointers to C
|
||||
offs_cm = column*TILE_M + tl.arange(0, TILE_M)
|
||||
offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N)
|
||||
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
|
||||
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
|
||||
pc = C + off_h * stride_hc \
|
||||
+ pidz * stride_zc \
|
||||
+ offs_cm[:, None] * stride_cm \
|
||||
+ offs_cn[None, :] * stride_cn
|
||||
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
|
||||
+ pidz * stride_zc \
|
||||
+ offs_cm[:, None] * stride_cm \
|
||||
+ offs_cn[None, :] * stride_cn
|
||||
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
|
||||
|
||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
|
||||
|
||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
@@ -231,7 +237,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out =
|
||||
# exit()
|
||||
return c
|
||||
|
||||
def dsd_lut(layout, block, step, trans, device):
|
||||
def dsd_lut(layout, block, step, trans, device):
|
||||
sizes = torch.sum(layout, 2 if trans else 1)
|
||||
head_id, col_id = sizes.nonzero(as_tuple=True)
|
||||
sizes = sizes.flatten()
|
||||
@@ -313,11 +319,11 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
# -----------------------------
|
||||
@triton.jit
|
||||
def _dds_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ka,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
DS0, DS1, lut,
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ka,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
|
||||
):
|
||||
@@ -348,7 +354,7 @@ def _dds_kernel(
|
||||
+ offs_ak[None, :] * stride_ka
|
||||
# initialize pointers to B (sparse)
|
||||
block_id = tl.load(pinc + 1)
|
||||
block_id = tl.multiple_of(block_id, 8)
|
||||
block_id = tl.multiple_of(block_id, 8)
|
||||
offs_bn = tl.arange(0, TILE_N)
|
||||
offs_bk = tl.arange(0, TILE_K)
|
||||
ptrs_b = B + pid_z * stride_zb \
|
||||
@@ -429,7 +435,7 @@ class _matmul(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
|
||||
):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
|
||||
@@ -499,10 +505,10 @@ class matmul:
|
||||
|
||||
def __call__(self, a, b, out = None):
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
self.db_lut, self.db_width,
|
||||
out
|
||||
)
|
||||
return c
|
||||
return c
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import triton.language as tl
|
||||
import triton
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(n):
|
||||
if n < 512:
|
||||
@@ -33,10 +34,10 @@ def _forward(
|
||||
check = rbn < size
|
||||
rbmn = tl.where(check, rbn, size - 1)
|
||||
# block id and column id
|
||||
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
|
||||
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
|
||||
columnid = tl.load(LUT + offset + rbmn * 4 + 1)
|
||||
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
|
||||
headid = tl.load(LUT + offset + rbmn * 4 + 3)
|
||||
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
|
||||
headid = tl.load(LUT + offset + rbmn * 4 + 3)
|
||||
# pointers to X
|
||||
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
x = tl.load(px, mask=check, other=-float('inf'))
|
||||
@@ -64,7 +65,7 @@ def _forward(
|
||||
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
|
||||
x = x + attn_m
|
||||
# apply causal mask
|
||||
is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm
|
||||
is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
|
||||
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
|
||||
# computation
|
||||
x = tl.softmax(x)
|
||||
@@ -127,9 +128,9 @@ class _softmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
kp_mask_mode, attn_mask_mode,
|
||||
ctx, x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
kp_mask_mode, attn_mask_mode,
|
||||
is_causal,
|
||||
spdims, block, lut, maxlut
|
||||
):
|
||||
@@ -161,15 +162,15 @@ class _softmax(torch.autograd.Function):
|
||||
# run kernel
|
||||
M = x.shape[0]
|
||||
grid = [spdims[0] * spdims[1] * block, M]
|
||||
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\
|
||||
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
BLOCK = block,
|
||||
APPLY_SCALE = apply_scale,
|
||||
APPLY_RPE = apply_rpe,
|
||||
APPLY_KP_MASK = apply_kp_mask,
|
||||
APPLY_ATTN_MASK = apply_attn_mask,
|
||||
KP_MASK_MUL = (kp_mask_mode == 'mul'),
|
||||
ATTN_MASK_MUL = (attn_mask_mode == 'mul'))
|
||||
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
|
||||
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
BLOCK=block,
|
||||
APPLY_SCALE=apply_scale,
|
||||
APPLY_RPE=apply_rpe,
|
||||
APPLY_KP_MASK=apply_kp_mask,
|
||||
APPLY_ATTN_MASK=apply_attn_mask,
|
||||
KP_MASK_MUL=(kp_mask_mode == 'mul'),
|
||||
ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
|
||||
# save to context
|
||||
ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(x, lut)
|
||||
@@ -211,10 +212,10 @@ class softmax:
|
||||
self.lut_cache = dict()
|
||||
|
||||
def __call__(
|
||||
self, x, scale=1., rpe=None,
|
||||
key_padding_mask=None, attn_mask=None,
|
||||
self, x, scale=1., rpe=None,
|
||||
key_padding_mask=None, attn_mask=None,
|
||||
key_padding_mask_mode='add', attn_mask_mode='add',
|
||||
is_causal = False
|
||||
is_causal=False
|
||||
):
|
||||
if rpe is not None and rpe.dtype != x.dtype:
|
||||
raise ValueError('relative position embedding must be %s' % x.dtype)
|
||||
@@ -224,11 +225,11 @@ class softmax:
|
||||
raise ValueError('Key padding mask must be %s' % x.dtype)
|
||||
lut, maxlut = self.make_lut(x.device)
|
||||
x = _softmax.apply(
|
||||
x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
key_padding_mask_mode, attn_mask_mode,
|
||||
x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
key_padding_mask_mode, attn_mask_mode,
|
||||
is_causal,
|
||||
self.spdims, self.block,
|
||||
self.spdims, self.block,
|
||||
lut, maxlut
|
||||
)
|
||||
return x
|
||||
return x
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
@@ -104,4 +106,4 @@ class _cross_entropy(torch.autograd.Function):
|
||||
return neg_logprobs, None
|
||||
|
||||
|
||||
cross_entropy = _cross_entropy.apply
|
||||
cross_entropy = _cross_entropy.apply
|
||||
|
@@ -1,11 +1,14 @@
|
||||
import torch
|
||||
import triton.language as tl
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .matmul_perf_model import *
|
||||
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
@@ -14,14 +17,15 @@ def get_configs_io_bound():
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
|
||||
})
|
||||
@@ -30,26 +34,26 @@ def get_configs_io_bound():
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'prune_num_stages_by' : prune_num_stages,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
'prune_num_stages_by': prune_num_stages,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
|
||||
# matrix multiplication
|
||||
@@ -68,12 +72,12 @@ def _kernel(A, B, C, M, N, K,
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(K, 0, -BLOCK_K*SPLIT_K):
|
||||
for k in range(K, 0, -BLOCK_K * SPLIT_K):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
@@ -117,10 +121,10 @@ class _matmul(torch.autograd.Function):
|
||||
c = torch.empty((M, N), device=device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8)
|
||||
return c
|
||||
|
||||
|
@@ -1,116 +1,121 @@
|
||||
import heapq
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
|
||||
import heapq
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
|
||||
return tflops
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
def estimate_matmul_time(
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
num_cta_m = triton.cdiv(M, BLOCK_M)
|
||||
num_cta_n = triton.cdiv(N, BLOCK_N)
|
||||
num_cta_k = SPLIT_K
|
||||
num_ctas = num_cta_m * num_cta_n * num_cta_k
|
||||
num_cta_m = triton.cdiv(M, BLOCK_M)
|
||||
num_cta_n = triton.cdiv(N, BLOCK_N)
|
||||
num_cta_k = SPLIT_K
|
||||
num_ctas = num_cta_m * num_cta_n * num_cta_k
|
||||
|
||||
# If the input is smaller than the block size
|
||||
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
|
||||
# If the input is smaller than the block size
|
||||
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
|
||||
|
||||
# time to compute
|
||||
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS
|
||||
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
|
||||
compute_ms = total_ops / tput
|
||||
# time to compute
|
||||
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
|
||||
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
|
||||
compute_ms = total_ops / tput
|
||||
|
||||
# time to load data
|
||||
num_sm = _triton.runtime.num_sm(backend, device)
|
||||
active_cta_ratio = min(1, num_ctas/num_sm)
|
||||
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate
|
||||
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5%
|
||||
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s
|
||||
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
|
||||
# assume 80% of (following) loads are in L2 cache
|
||||
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2)
|
||||
load_a_l2 = M*K*2*0.8*(num_cta_n-1)
|
||||
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
|
||||
load_b_l2 = N*K*2*0.8*(num_cta_m-1)
|
||||
# total
|
||||
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB
|
||||
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
|
||||
# loading time in ms
|
||||
load_ms = total_dram/dram_bw + total_l2/l2_bw
|
||||
# time to load data
|
||||
num_sm = _triton.runtime.num_sm(backend, device)
|
||||
active_cta_ratio = min(1, num_ctas / num_sm)
|
||||
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
|
||||
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
|
||||
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
|
||||
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
|
||||
# assume 80% of (following) loads are in L2 cache
|
||||
load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2)
|
||||
load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
|
||||
load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
|
||||
load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
|
||||
# total
|
||||
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
|
||||
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
|
||||
# loading time in ms
|
||||
load_ms = total_dram / dram_bw + total_l2 / l2_bw
|
||||
|
||||
# estimate storing time
|
||||
store_bw = dram_bw * 0.6 # :o
|
||||
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB
|
||||
if SPLIT_K == 1:
|
||||
store_ms = store_c_dram /store_bw
|
||||
else:
|
||||
reduce_bw = store_bw
|
||||
store_ms = store_c_dram/reduce_bw
|
||||
# c.zero_()
|
||||
zero_ms = M*N*2/(1024*1024)/store_bw
|
||||
store_ms += zero_ms
|
||||
# estimate storing time
|
||||
store_bw = dram_bw * 0.6 # :o
|
||||
store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB
|
||||
if SPLIT_K == 1:
|
||||
store_ms = store_c_dram / store_bw
|
||||
else:
|
||||
reduce_bw = store_bw
|
||||
store_ms = store_c_dram / reduce_bw
|
||||
# c.zero_()
|
||||
zero_ms = M * N * 2 / (1024 * 1024) / store_bw
|
||||
store_ms += zero_ms
|
||||
|
||||
total_time_ms = max(compute_ms, load_ms) + store_ms
|
||||
if debug:
|
||||
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
|
||||
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
|
||||
f'Activate CTAs: {active_cta_ratio*100}%')
|
||||
return total_time_ms
|
||||
|
||||
total_time_ms = max(compute_ms, load_ms) + store_ms
|
||||
if debug:
|
||||
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
|
||||
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
|
||||
f'Activate CTAs: {active_cta_ratio*100}%')
|
||||
return total_time_ms
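For reference, a hedged sketch of calling the analytical model above directly for a single candidate configuration; the import path and the sizes are illustrative assumptions, not part of this change, and a CUDA device is required:

```
from triton.ops.matmul_perf_model import estimate_matmul_time

# estimate one compute-bound candidate for a 4096^3 fp16 matmul;
# debug=True prints the compute/load/store breakdown computed above
ms = estimate_matmul_time(
    num_warps=4, num_stages=4,
    M=4096, N=4096, K=4096,
    BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, SPLIT_K=1,
    debug=True,
)
```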
|
||||
|
||||
def prune_num_stages(configs):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
|
||||
|
||||
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
|
||||
configs_map = {}
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
|
||||
|
||||
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
|
||||
if key in configs_map:
|
||||
configs_map[key].append((config, num_stages))
|
||||
else:
|
||||
configs_map[key] = [(config, num_stages)]
|
||||
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
|
||||
configs_map = {}
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
|
||||
|
||||
pruned_configs = []
|
||||
for k, v in configs_map.items():
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
|
||||
if cc >= 80:
|
||||
# compute cycles (only works for ampere GPUs)
|
||||
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
|
||||
mma_cycles = mmas/min(4, num_warps) * 8
|
||||
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
|
||||
if key in configs_map:
|
||||
configs_map[key].append((config, num_stages))
|
||||
else:
|
||||
configs_map[key] = [(config, num_stages)]
|
||||
|
||||
ldgsts_latency = 300 # Does this matter?
|
||||
optimal_num_stages = ldgsts_latency/mma_cycles
|
||||
pruned_configs = []
|
||||
for k, v in configs_map.items():
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
|
||||
if cc >= 80:
|
||||
# compute cycles (only works for ampere GPUs)
|
||||
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
|
||||
mma_cycles = mmas / min(4, num_warps) * 8
|
||||
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
ldgsts_latency = 300 # Does this matter?
|
||||
optimal_num_stages = ldgsts_latency / mma_cycles
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
else: # Volta & Turing only supports num_stages <= 2
|
||||
random_config = v[0][0]
|
||||
random_config.num_stages = 2
|
||||
pruned_configs.append(random_config)
|
||||
return pruned_configs
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
else: # Volta & Turing only supports num_stages <= 2
|
||||
random_config = v[0][0]
|
||||
random_config.num_stages = 2
|
||||
pruned_configs.append(random_config)
|
||||
return pruned_configs
|
||||
|
@@ -1,10 +1,11 @@
|
||||
import torch
|
||||
import os
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .code_gen import OutOfResources
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .code_gen import OutOfResources
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
@@ -13,6 +14,7 @@ except ImportError:
|
||||
_cutlass = None
|
||||
has_cutlass = False
|
||||
|
||||
|
||||
def catch_oor(kernel, pytest_handle=None):
|
||||
try:
|
||||
res = kernel()
|
||||
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
|
||||
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
|
||||
# run function
|
||||
dtype = str(a.dtype).split('.')[-1]
|
||||
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
|
||||
M, N, Ka,\
|
||||
a.stride(0), a.stride(1),\
|
||||
b.stride(0), b.stride(1),\
|
||||
c.stride(0), c.stride(1),\
|
||||
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
M, N, Ka,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
dtype, dtype, dtype,
|
||||
a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
|
||||
|
||||
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
|
||||
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
|
||||
return ret
|
||||
|
||||
|
||||
def assert_almost_equal(x, y, decimal=2, err_msg=''):
|
||||
import numpy.testing as npt
|
||||
if isinstance(x, torch.Tensor):
|
||||
@@ -93,6 +96,7 @@ def nvsmi(attrs):
|
||||
ret = [int(x) for x in ret]
|
||||
return ret
|
||||
|
||||
|
||||
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
|
||||
"""
|
||||
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
|
||||
@@ -122,13 +126,13 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
|
||||
torch.cuda.synchronize()
|
||||
estimate_ms = start_event.elapsed_time(end_event) / 5
|
||||
# compute number of warmup and repeat
|
||||
n_warmup = max(1, int(warmup/estimate_ms))
|
||||
n_repeat = max(1, int(rep/estimate_ms))
|
||||
n_warmup = max(1, int(warmup / estimate_ms))
|
||||
n_repeat = max(1, int(rep / estimate_ms))
|
||||
# We maintain a buffer of 256 MB that we clear
|
||||
# before each kernel call to make sure that the L2
|
||||
# doesn't contain any input data before the run
|
||||
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
|
||||
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
|
||||
# Warm-up
|
||||
for _ in range(n_warmup):
|
||||
@@ -161,6 +165,7 @@ class Benchmark:
|
||||
"""
|
||||
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_names,
|
||||
@@ -224,9 +229,10 @@ class Mark:
|
||||
self.benchmarks = benchmarks
|
||||
|
||||
def _run(self, bench, save_path, show_plots, print_data):
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import os
|
||||
y_mean = bench.line_names
|
||||
y_min = [f'{x}-min' for x in bench.line_names]
|
||||
y_max = [f'{x}-max' for x in bench.line_names]
|
||||
@@ -259,7 +265,7 @@ class Mark:
|
||||
xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylabel(bench.ylabel)
|
||||
#ax.set_title(bench.plot_name)
|
||||
# ax.set_title(bench.plot_name)
|
||||
ax.set_xscale("log" if bench.x_log else "linear")
|
||||
ax.set_yscale("log" if bench.y_log else "linear")
|
||||
if show_plots:
|
||||
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
|
||||
wrapper = lambda fn: Mark(fn, benchmarks)
|
||||
return wrapper
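Since this file defines the whole benchmarking surface (:code:`do_bench`, :code:`Benchmark`, :code:`Mark`, :code:`perf_report`), here is a minimal sketch of how the pieces fit together; the function name and sizes are made up for illustration and a CUDA device is needed to run it:

```
import torch

import triton


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(12, 20)],
        line_arg='provider',
        line_vals=['torch'],
        line_names=['Torch'],
        ylabel='GB/s',
        plot_name='vector-add',
        args={},
    )
)
def bench_add(N, provider):
    x = torch.randn(N, device='cuda')
    y = torch.randn(N, device='cuda')
    # do_bench returns (median, min, max) runtimes in ms
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
    gbps = lambda t: 3 * x.numel() * x.element_size() / t * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)


# bench_add.run(show_plots=False, print_data=True)
```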
|
||||
|
||||
|
||||
def get_dram_gbps(backend=None, device=None):
|
||||
''' return DRAM bandwidth in GB/s '''
|
||||
# assert backend == CUDA
|
||||
@@ -306,17 +313,18 @@ def get_dram_gbps(backend=None, device=None):
|
||||
device = torch.cuda.current_device()
|
||||
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
|
||||
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
|
||||
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
|
||||
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
|
||||
return bw_gbps
|
||||
|
||||
|
||||
def get_max_tensorcore_tflops(backend, device):
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
# assume fp32 += fp16*fp16
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80:
|
||||
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
|
||||
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
|
||||
else:
|
||||
ops_per_sub_core = 512
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
|
||||
return tflops
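As a sanity check of the formula above, a worked example assuming an A100-like part (108 SMs, ~1410 MHz boost clock, compute capability >= 80; these numbers are illustrative assumptions):

```
num_subcores = 108 * 4                 # 432 tensor-core subcores
clock_rate = 1_410_000                 # kHz
ops_per_sub_core = 512                 # cc >= 80
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
print(tflops)                          # ~290 in these binary-scaled units (~312e12 ops/s in decimal)
```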
|
||||
|
@@ -21,8 +21,8 @@
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
|
||||
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
|
||||
|
@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
|
||||
# Compute Kernel
|
||||
# --------------------------
|
||||
|
||||
from triton.language.core import constexpr
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -38,7 +38,7 @@ def add_kernel(
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
# Create a mask to guard memory operations against out-of-bounds accesses
|
||||
mask = offsets < n_elements
|
||||
# Load x and y from DRAM, masking out any extra elements in case
|
||||
# Load x and y from DRAM, masking out any extra elements in case
|
||||
# the input is not a multiple of the block size
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
|
@@ -16,6 +16,8 @@ You will learn about:
|
||||
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
import triton.language as tl
|
||||
import triton
|
||||
import torch
|
||||
|
||||
|
||||
@@ -59,13 +61,10 @@ def naive_softmax(x):
|
||||
# power-of-two number of elements, so we need to internally "pad" each row and guard the
|
||||
# memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
@@ -136,7 +135,7 @@ y_triton = softmax(x)
|
||||
y_torch = torch.softmax(x, axis=1)
|
||||
print(torch.allclose(y_triton, y_torch))
|
||||
|
||||
#%%
|
||||
# %%
|
||||
# As expected, the results are identical.
|
||||
|
||||
# %%
|
||||
@@ -187,5 +186,5 @@ benchmark.run(show_plots=True, print_data=True)
|
||||
# In the above plot, we can see that:
|
||||
#
|
||||
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
|
||||
|
@@ -112,13 +112,13 @@ You will specifically learn about:
# # number of program ids along the N axis
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# # number of programs in group
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# # id of the group this program is in
# group_id = pid // num_pid_in_group
# # row-id of the first program in the group
# first_pid_m = group_id * GROUP_SIZE_M
# # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# # *within groups*, programs are ordered in a column-major order
# # row-id of the program in the *launch grid*
# pid_m = first_pid_m + (pid % group_size_m)
# (a small host-side sketch of this re-ordering follows below)
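A tiny host-side sketch (hypothetical helper, not part of the tutorial) that applies exactly this arithmetic for a 4x4 grid of programs with :code:`GROUP_SIZE_M = 2` and prints the order in which consecutive program ids visit the output tiles:

```
def grouped_order(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

print([grouped_order(pid, 4, 4, 2) for pid in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
```

Consecutive programs walk down a group of two rows before moving to the next column, which is the data-reuse pattern the tutorial is after.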
@@ -141,6 +141,7 @@ You will specifically learn about:
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -152,18 +153,19 @@ import triton.language as tl
|
||||
# - An autotuning *key* whose change in values will trigger evaluation of all the
|
||||
# provided configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -185,7 +187,7 @@ def matmul_kernel(
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
@@ -196,16 +198,16 @@ def matmul_kernel(
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
|
||||
@@ -213,8 +215,8 @@ def matmul_kernel(
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix
|
||||
@@ -223,8 +225,8 @@ def matmul_kernel(
|
||||
# `accumulator` will be converted back to fp16 after the loop
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
# Note that for simplicity, we don't apply a mask here.
|
||||
# This means that if K is not a multiple of BLOCK_SIZE_K,
|
||||
# Note that for simplicity, we don't apply a mask here.
|
||||
# This means that if K is not a multiple of BLOCK_SIZE_K,
|
||||
# this will access out-of-bounds memory and produce an
|
||||
# error or (worse!) incorrect results.
|
||||
a = tl.load(a_ptrs)
|
||||
@@ -236,7 +238,7 @@ def matmul_kernel(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
# you can fuse arbitrary activation functions here
|
||||
# while the accumulator is still in FP32 !
|
||||
if meta['ACTIVATION']:
|
||||
if meta['ACTIVATION']:
|
||||
accumulator = meta['ACTIVATION'](accumulator)
|
||||
c = accumulator.to(tl.float16)
|
||||
|
||||
|
@@ -13,7 +13,7 @@ whose state is generally composed of a bit mask tensor of the same shape as the
|
||||
# %%
|
||||
# Baseline
|
||||
# -------------
|
||||
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
|
||||
# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
|
||||
# of deep neural networks in low-data regime (i.e. regularization).
|
||||
#
|
||||
# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the
|
||||
@@ -30,16 +30,18 @@ whose state is generally composed of a bit mask tensor of the same shape as the
|
||||
|
||||
import tabulate
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dropout(
|
||||
x_ptr, # pointer to the input
|
||||
x_keep_ptr, # pointer to a mask of 0s and 1s
|
||||
output_ptr, # pointer to the output
|
||||
n_elements, # number of elements in the `x` tensor
|
||||
p, # probability that an element of `x` is changed to zero
|
||||
x_ptr, # pointer to the input
|
||||
x_keep_ptr, # pointer to a mask of 0s and 1s
|
||||
output_ptr, # pointer to the output
|
||||
n_elements, # number of elements in the `x` tensor
|
||||
p, # probability that an element of `x` is changed to zero
|
||||
**meta,
|
||||
):
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE']
|
||||
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
|
||||
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
|
||||
return output
|
||||
|
||||
|
||||
# Input tensor
|
||||
x = torch.randn(size=(10,)).cuda()
|
||||
# Dropout mask
|
||||
@@ -88,7 +91,7 @@ print(tabulate.tabulate([
|
||||
# of persisting randomness across multiple invocations of the kernel.
|
||||
#
|
||||
# Pseudorandom number generation in Triton is simple! In this tutorial we will use the
|
||||
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
|
||||
# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32`
|
||||
# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
|
||||
# other :ref:`random number generation strategies <Random Number Generation>`.
|
||||
#
|
||||
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
|
||||
#
|
||||
# Let's put it all together.
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _seeded_dropout(
|
||||
x_ptr,
|
||||
|
@@ -4,15 +4,17 @@ Layer Normalization
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# Forward Pass
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
# position of elements processed by this program
|
||||
row = tl.program_id(0)
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N
|
||||
# offset data pointers to start at the row of interest
|
||||
@@ -24,9 +26,9 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
# compute std
|
||||
xmean = tl.where(mask, x - mean, 0.)
|
||||
var = tl.sum(xmean * xmean, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
xhat = xmean*rstd
|
||||
var = tl.sum(xmean * xmean, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
xhat = xmean * rstd
|
||||
# write-back mean/rstd
|
||||
tl.store(M + row, mean)
|
||||
tl.store(V + row, rstd)
|
||||
@@ -41,16 +43,16 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
|
||||
# Backward pass (DX + partial DW + partial DB)
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
|
||||
stride, N, eps,
|
||||
**META):
|
||||
stride, N, eps,
|
||||
**META):
|
||||
GROUP_SIZE_M = META['GROUP_SIZE_M']
|
||||
BLOCK_SIZE_N = META['BLOCK_SIZE_N']
|
||||
# position of elements processed by this program
|
||||
row = tl.program_id(0)
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < N
|
||||
# offset data pointers to start at the row of interest
|
||||
X += row * stride
|
||||
X += row * stride
|
||||
DY += row * stride
|
||||
DX += row * stride
|
||||
# offset locks and weight/bias gradient pointer
|
||||
@@ -59,28 +61,28 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
|
||||
# these buffers stay in the L2, which allow this kernel
|
||||
# to be fast
|
||||
lock_id = row % GROUP_SIZE_M
|
||||
Lock += lock_id
|
||||
Count = Lock + GROUP_SIZE_M
|
||||
DW = DW + lock_id*N + cols
|
||||
DB = DB + lock_id*N + cols
|
||||
Lock += lock_id
|
||||
Count = Lock + GROUP_SIZE_M
|
||||
DW = DW + lock_id * N + cols
|
||||
DB = DB + lock_id * N + cols
|
||||
# load data to SRAM
|
||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
mean = tl.load(M + row)
|
||||
rstd = tl.load(V + row)
|
||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
mean = tl.load(M + row)
|
||||
rstd = tl.load(V + row)
|
||||
# compute dx
|
||||
xhat = (x - mean)*rstd
|
||||
wdy = w * dy
|
||||
xhat = tl.where(mask, xhat, 0.)
|
||||
wdy = tl.where(mask, wdy , 0.)
|
||||
xhat = (x - mean) * rstd
|
||||
wdy = w * dy
|
||||
xhat = tl.where(mask, xhat, 0.)
|
||||
wdy = tl.where(mask, wdy, 0.)
|
||||
mean1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
mean2 = tl.sum(wdy, axis=0) / N
|
||||
dx = (wdy - (xhat*mean1 + mean2))*rstd
|
||||
dx = (wdy - (xhat * mean1 + mean2)) * rstd
|
||||
# write-back dx
|
||||
tl.store(DX + cols, dx, mask=mask)
|
||||
# accumulate partial sums for dw/db
|
||||
partial_dw = (dy*xhat).to(w.dtype)
|
||||
partial_dw = (dy * xhat).to(w.dtype)
|
||||
partial_db = (dy).to(w.dtype)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
pass
|
||||
@@ -97,24 +99,27 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
# Backward pass (total DW + total DB)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
|
||||
pid = tl.program_id(0)
|
||||
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
|
||||
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
||||
cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None]*N + cols[None, :]
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
dw += tl.load(DW + offs, mask=mask, other=0.)
|
||||
db += tl.load(DB + offs, mask=mask, other=0.)
|
||||
sum_dw = tl.sum(dw, axis=0)
|
||||
sum_db = tl.sum(db, axis=0)
|
||||
tl.store(FINAL_DW + cols, sum_dw, mask=cols<N)
|
||||
tl.store(FINAL_DB + cols, sum_db, mask=cols<N)
|
||||
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
|
||||
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
|
||||
|
||||
|
||||
class LayerNorm(torch.autograd.Function):
|
||||
|
||||
@@ -129,19 +134,19 @@ class LayerNorm(torch.autograd.Function):
|
||||
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_SIZE:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
# enqueue kernel
|
||||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
|
||||
x_arg.stride(0), N, eps,
|
||||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
|
||||
x_arg.stride(0), N, eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.eps = eps
|
||||
ctx.num_warps = num_warps
|
||||
ctx.eps = eps
|
||||
return y
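A quick worked example of the sizing heuristic above (illustrative numbers, assuming :code:`float16` inputs with 768 columns):

```
element_size = 2                               # torch.float16
MAX_FUSED_SIZE = 65536 // element_size         # 32768
BLOCK_SIZE = min(MAX_FUSED_SIZE, 1024)         # triton.next_power_of_2(768) == 1024
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)  # 4
print(MAX_FUSED_SIZE, BLOCK_SIZE, num_warps)   # 32768 1024 4
```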
|
||||
|
||||
@staticmethod
|
||||
@@ -154,11 +159,11 @@ class LayerNorm(torch.autograd.Function):
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda')
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
@@ -166,14 +171,14 @@ class LayerNorm(torch.autograd.Function):
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M = 32,
BLOCK_SIZE_N = 128)
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None


@@ -184,10 +189,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x)
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
@@ -218,14 +224,14 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'):
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1*torch.randn_like(x)
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# utility functions
if provider == 'triton':
@@ -238,14 +244,15 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
y_fwd = lambda: apex_layer_norm(x)
# forward pass
if mode == 'forward':
gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)


bench_layer_norm.run(save_path='.', print_data=True)