[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
# ignored pycodestyle checks: E501 (line too long), E701 (multiple statements on one line), E731 (lambda assignment)
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg`, and then clean up a few warts. This PR should be a functional no-op: it contains only boring whitespace and import-ordering changes, and any config-file changes will land in a separate change to make review easier.
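The `.isort.cfg` referenced above is deliberately deferred to that separate change, so it does not appear in this diff. Purely as an illustration of the kind of file involved — the option names below are real isort settings, but the values are assumptions, not the project's actual configuration — a minimal `.isort.cfg` might look like:
```
[settings]
line_length = 120
known_first_party = triton
```
Keeping the configuration out of this PR means the mechanical reformat stays trivially reviewable, which is the point of splitting the changes.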
Madeleine Thompson, 2022-01-06 14:34:17 -08:00 (committed by GitHub)
parent 120cda015e
commit 8bf551ae7a
30 changed files with 742 additions and 623 deletions

View File

@@ -1,4 +1,5 @@
import torch import torch
import triton import triton
# ------------------------------- # -------------------------------
@@ -17,8 +18,8 @@ square_confs = [
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args={'layout_mode': layout_mode, 'op_mode': op_mode, args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\ )
for AT in [False] for BT in [False] \ for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense'] for op_mode in ['dsd'] for layout_mode in ['dense']
] ]
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1 Z, H = 1, 1
make_layout = { make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode] }[layout_mode]
# create layout # create layout
@@ -45,8 +46,8 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = { num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\ 'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\ 'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block 'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode] * 1e-12 }[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms) return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -66,7 +67,7 @@ square_confs = [
ylabel='GBPS', ylabel='GBPS',
plot_name=f'{layout_mode}-square', plot_name=f'{layout_mode}-square',
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\ )
for layout_mode in ['dense', 'tril'] for layout_mode in ['dense', 'tril']
] ]

View File

@@ -1,4 +1,5 @@
import torch import torch
import triton import triton
confs = [ confs = [
@@ -11,7 +12,7 @@ confs = [
ylabel='GBPS', ylabel='GBPS',
plot_name=f'{mode}-2048', plot_name=f'{mode}-2048',
args={'M': 2048, 'dtype': torch.float16, 'mode': mode} args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\ )
for mode in ['forward', 'backward'] for mode in ['forward', 'backward']
] ]
@@ -24,7 +25,7 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9) num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3 gbps = lambda ms: num_gb / ms * 1e3
# forward pass # forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \ op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider] 'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward': if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx)) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))

View File

@@ -1,6 +1,6 @@
import triton
import torch import torch
import os
import triton
def rounded_linspace(low, high, steps, div): def rounded_linspace(low, high, steps, div):
@@ -36,8 +36,8 @@ transformer_confs = [
ylabel="TFLOPS", ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16} args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\ ) for NK in [12288]
for i, x in enumerate(["N", "K"])\ for i, x in enumerate(["N", "K"])
for M in [2048] for M in [2048]
] ]
@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t() if AT:
if BT: b = b.t() a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9 tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas": if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try: try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep) ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms) return tflops(ms), tflops(max_ms), tflops(min_ms)
except: except Exception:
return None return None
return None return None

View File

@@ -1,7 +1,8 @@
import argparse import argparse
import sys
import os
import inspect import inspect
import os
import sys
import triton import triton

View File

@@ -1,20 +1,19 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils import distutils
import glob
import tempfile
import shutil
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn import distutils.spawn
import urllib.request import os
import platform
import re
import shutil
import subprocess
import sys
import tarfile import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
def get_llvm(): def get_llvm():
# tries to find system LLVM # tries to find system LLVM
@@ -32,7 +31,7 @@ def get_llvm():
if not os.path.exists(llvm_library_dir): if not os.path.exists(llvm_library_dir):
try: try:
shutil.rmtree(os.path.join(dir, name)) shutil.rmtree(os.path.join(dir, name))
except: except Exception:
pass pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name) url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...') print('downloading and extracting ' + url + '...')

View File

@@ -1,14 +1,18 @@
from numpy import record import triton.language as tl
import torch
import triton
import subprocess import subprocess
import sys import sys
import pytest import pytest
import torch
from numpy import record
import triton
####################### #######################
# Utilities # Utilities
####################### #######################
def nvsmi(attrs): def nvsmi(attrs):
attrs = ','.join(attrs) attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -46,6 +50,8 @@ matmul_data = {
# (256 , 256 , 8192 ) : {'v100': 0.}, # (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.}, # (256 , 256 , 32768) : {'v100': 0.},
} }
@pytest.mark.parametrize('M, N, K', matmul_data.keys()) @pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K): def test_matmul(M, N, K):
ref_gpu_util = matmul_data[(M, N, K)]['v100'] ref_gpu_util = matmul_data[(M, N, K)]['v100']
@@ -61,10 +67,11 @@ def test_matmul(M, N, K):
cur_gpu_util = cur_gpu_perf / max_gpu_perf cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
####################### #######################
# Element-Wise # Element-Wise
####################### #######################
import triton.language as tl
@triton.jit @triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements, def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -89,6 +96,7 @@ elementwise_data = {
1024 * 65536: {'v100': 0.939}, 1024 * 65536: {'v100': 0.939},
} }
@pytest.mark.parametrize('N', elementwise_data.keys()) @pytest.mark.parametrize('N', elementwise_data.keys())
def test_elementwise(N): def test_elementwise(N):
ref_gpu_util = elementwise_data[N]['v100'] ref_gpu_util = elementwise_data[N]['v100']
@@ -105,4 +113,3 @@ def test_elementwise(N):
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6 cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'): def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128 SIZE = 128
@triton.jit @triton.jit
def kernel(X, SIZE: tl.constexpr): def kernel(X, SIZE: tl.constexpr):
pass pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128 SIZE = 128
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
def kernel(Z, X, SIZE: tl.constexpr): def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE) off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128 SIZE = 128
# define the kernel / launch-grid # define the kernel / launch-grid
@triton.jit @triton.jit
def kernel(Z, X, Y, SIZE: tl.constexpr): def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE) off = tl.arange(0, SIZE)
@@ -206,6 +209,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# --------------- # ---------------
# test binary ops # test binary ops
# --------------- # ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [ @pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op) (dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%'] for op in ['+', '-', '*', '/', '%']
@@ -298,16 +303,18 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# test compare ops # test compare ops
# --------------- # ---------------
ops = ['==', '!=', '>', '<', '>=', '<='] ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
# real # real
[ [
(dtype_x, dtype_y, op, 'real', 'real') \ (dtype_x, dtype_y, op, 'real', 'real')
for op in ops \ for op in ops
for dtype_x in dtypes \ for dtype_x in dtypes
for dtype_y in dtypes for dtype_y in dtypes
] + \ ] +
# NaNs # NaNs
[('float32', 'float32', op, mode_x, mode_y) \ [('float32', 'float32', op, mode_x, mode_y)
for op in ops for op in ops
for mode_x, mode_y in [('nan', 'real'), for mode_x, mode_y in [('nan', 'real'),
('real', 'nan'), ('real', 'nan'),
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# 'exp', 'log', 'cos', 'sin' # 'exp', 'log', 'cos', 'sin'
# ]) # ])
@pytest.mark.parametrize("expr", [ @pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin' 'exp', 'log', 'cos', 'sin'
]) ])
@@ -558,9 +566,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# --------------- # ---------------
# test reduce # test reduce
# --------------- # ---------------
@pytest.mark.parametrize("dtype_str, shape", @pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape) \ [(dtype, shape)
for dtype in dtypes\ for dtype in dtypes
for shape in [128, 512]]) for shape in [128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'): def test_reduce1d(dtype_str, shape, device='cuda'):
@@ -608,10 +618,12 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# --------------- # ---------------
# test permute # test permute
# --------------- # ---------------
@pytest.mark.parametrize("dtype_str, shape, perm", @pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) \ [(dtype, shape, perm)
for dtype in ['float32']\ for dtype in ['float32']
for shape in [(128, 128)]\ for shape in [(128, 128)]
for perm in [(1, 0)]]) for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'): def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# test dot # test dot
# --------------- # ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'): def test_dot(epilogue, device='cuda'):
# triton kernel # triton kernel
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx assert 'st.global.v4' in ptx
def test_dot_without_load(): def test_dot_without_load():
@triton.jit @triton.jit
def kernel(out): def kernel(out):
@@ -723,10 +737,12 @@ def test_dot_without_load():
# test arange # test arange
# --------------- # ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16]) @pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'): def test_arange(start, device='cuda'):
BLOCK = 128 BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit @triton.jit
def _kernel(z, BLOCK: tl.constexpr, def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr): START: tl.constexpr, END: tl.constexpr):
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
# --------------- # ---------------
# 'bfloat16': torch.bfloat16, # 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run. # Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'): def test_masked_load_shared_memory(dtype, device='cuda'):
M = 32 M = 32
@@ -788,6 +806,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
reference_out = torch.matmul(in1, in2) reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out) triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache): def test_load_cache_modifier(cache):
src = torch.empty(128, device='cuda') src = torch.empty(128, device='cuda')
@@ -831,10 +850,13 @@ def test_load_cache_modifier(cache):
# test default # test default
# --------------- # ---------------
# TODO: can't be local to test_default # TODO: can't be local to test_default
@triton.jit @triton.jit
def _impl(value=10): def _impl(value=10):
return value return value
def test_default(): def test_default():
value = 5 value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -852,6 +874,8 @@ def test_default():
# --------------- # ---------------
# test noop # test noop
# ---------------- # ----------------
def test_noop(device='cuda'): def test_noop(device='cuda'):
@triton.jit @triton.jit
def kernel(x): def kernel(x):

View File

@@ -1,16 +1,17 @@
import torch import numpy as np
import triton
import triton.language as tl
import pytest import pytest
import scipy.stats import scipy.stats
import numpy as np import torch
from numpy.random import Philox from numpy.random import Philox
import triton
import triton.language as tl
##################################### #####################################
## Reference Philox Implementation # Reference Philox Implementation
##################################### #####################################
class PhiloxConfig: class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
##################################### #####################################
## Unit Tests # Unit Tests
##################################### #####################################
BLOCK = 1024 BLOCK = 1024
# test generation of random uint32 # test generation of random uint32
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\ [(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
) )
def test_randint(size, seed, device='cuda'): def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(','))) size = list(map(int, size.split(',')))
@triton.jit @triton.jit
def kernel(X, N, seed): def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,8 +136,10 @@ def test_randint(size, seed, device='cuda'):
assert out_tri == out_ref assert out_tri == out_ref
# test uniform PRNG # test uniform PRNG
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\ [(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]] for seed in [0, 42, 124, 54]]
) )
def test_rand(size, seed, device='cuda'): def test_rand(size, seed, device='cuda'):
@@ -151,8 +157,10 @@ def test_rand(size, seed, device='cuda'):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG # test normal PRNG
@pytest.mark.parametrize('size, seed', @pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\ [(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]] for seed in [0, 42, 124, 54]]
) )
def test_randn(size, seed, device='cuda'): def test_randn(size, seed, device='cuda'):

View File

@@ -1,6 +1,7 @@
import torch
import triton
import pytest import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
# torch result # torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx # broadcast at_mask to the same shape as rx
if is_causal: at_mask = torch.tril(at_mask) if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx) M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf") rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :] # rx += kp_mask[:, None, None, :]

View File

@@ -1,12 +1,14 @@
import torch
import triton
import pytest import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode", @pytest.mark.parametrize("M, N, dtype, mode",
[ [
(M, N, dtype, mode) for M in [1024, 821] (M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000] for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']\ for dtype in ['float16', 'float32']
for mode in ['forward', 'backward'] for mode in ['forward', 'backward']
] ]
) )

View File

@@ -1,8 +1,10 @@
import pytest
import itertools import itertools
import triton
import pytest
import torch import torch
import triton
@pytest.mark.parametrize( @pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE", "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",

View File

@@ -1,13 +1,16 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os import os
import shutil import shutil
import pytest import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp" tmpdir = ".tmp"
@triton.jit @triton.jit
def function_1(i): def function_1(i):
i = i + 1 i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1 i = i + 1
return i return i
@triton.jit @triton.jit
def kernel(X, i, BLOCK: tl.constexpr): def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1 i = i + 1
i = function_1(i) i = function_1(i)
tl.store(X, i) tl.store(X, i)
@triton.jit(do_not_specialize=["i"]) @triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr): def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1 i = i + 1
i = function_1(i) i = function_1(i)
tl.store(X, i) tl.store(X, i)
def apply_src_change(target, old, new): def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash') delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash') delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old) target.src = target.src.replace(new, old)
return ret return ret
def test_nochange(): def test_nochange():
baseline = kernel.cache_key baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1') updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated assert baseline == updated
def test_toplevel_change(): def test_toplevel_change():
baseline = kernel.cache_key baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2') updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated assert baseline != updated
def test_nested1_change(): def test_nested1_change():
baseline = kernel.cache_key baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2') updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated assert baseline != updated
def reset_tmp_dir(): def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir): if os.path.exists(tmpdir):
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
def test_reuse(): def test_reuse():
counter = 0 counter = 0
def inc_counter(key, binary, repr): def inc_counter(key, binary, repr):
nonlocal counter nonlocal counter
counter += 1 counter += 1
@@ -78,6 +90,7 @@ def test_reuse():
@pytest.mark.parametrize('mode', ['enable', 'disable']) @pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode): def test_specialize(mode):
counter = 0 counter = 0
def inc_counter(key, binary, repr): def inc_counter(key, binary, repr):
nonlocal counter nonlocal counter
counter += 1 counter += 1

View File

@@ -1,9 +1,11 @@
import torch
import triton
import pytest
import subprocess import subprocess
import triton.language as tl
import numpy as np import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix(): def get_p2p_matrix():

View File

@@ -1,26 +1,26 @@
import ast import ast
import builtins import builtins
import dbm
import functools import functools
import inspect
import struct
import sys
import textwrap
import hashlib import hashlib
import inspect
import os import os
import pickle import pickle
import struct
import subprocess import subprocess
import os import sys
import tempfile
import textwrap
import time
import warnings import warnings
from .tools.disasm import extract from typing import Dict, Optional
import torch import torch
from filelock import FileLock
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from filelock import FileLock from .tools.disasm import extract
import dbm
import tempfile
from typing import Optional, Dict
import time
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
@@ -135,7 +135,6 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(fn.args[idx]) arg_values.append(fn.args[idx])
idx += 1 idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value) self.set_value(arg_name, arg_value)
if inline: if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
# default: call visit_Assign # default: call visit_Assign
return self.visit_Assign(node) return self.visit_Assign(node)
def visit_Assign(self, node): def visit_Assign(self, node):
_names = [] _names = []
for target in node.targets: for target in node.targets:
@@ -404,9 +402,9 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ build_cond = lambda: triton.language.where(self.visit(pos_step_node),
self.visit(pos_cond_node),\ self.visit(pos_cond_node),
self.visit(neg_cond_node),\ self.visit(neg_cond_node),
_builder=self.builder) _builder=self.builder)
#cond_node = neg_cond_node #cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
@@ -632,10 +630,14 @@ class Kernel:
@staticmethod @staticmethod
def pow2_divisor(N): def pow2_divisor(N):
if N % 16 == 0: return 16 if N % 16 == 0:
if N % 8 == 0: return 8 return 16
if N % 4 == 0: return 4 if N % 8 == 0:
if N % 2 == 0: return 2 return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1 return 1
def __init__(self, fn): def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes # attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize} if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation # transforms ints whose value is one into constants for just-in-time compilation
@@ -768,7 +770,6 @@ class Launcher:
return self.kernel(*wargs, **kwargs, grid=self.grid) return self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner: class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
''' '''
@@ -788,6 +789,7 @@ class Autotuner:
self.hook = lambda args: 0 self.hook = lambda args: 0
if reset_to_zero is not None: if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero] self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args): def _hook(args):
for i in self.reset_idx: for i in self.reset_idx:
args[i].zero_() args[i].zero_()
@@ -814,6 +816,7 @@ class Autotuner:
) )
# augment meta-parameters with tunable ones # augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs) current = dict(meta, **config.kwargs)
def kernel_call(): def kernel_call():
if config.pre_hook: if config.pre_hook:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
@@ -838,7 +841,7 @@ class Autotuner:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time() bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \ timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs} for config in pruned_configs}
bench_end = time.time() bench_end = time.time()
self.bench_time = bench_end - bench_start self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
ptxas_version = '' ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
#########################3 # 3
class DependenciesFinder(ast.NodeVisitor): class DependenciesFinder(ast.NodeVisitor):
@@ -917,11 +920,11 @@ class DependenciesFinder(ast.NodeVisitor):
self.ret = (self.ret + func.hash).encode("utf-8") self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest() self.ret = hashlib.md5(self.ret).hexdigest()
class JITFunction: class JITFunction:
cache_hook = None cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None): def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function # information of wrapped function
self.fn = fn self.fn = fn
@@ -946,7 +949,6 @@ class JITFunction:
# forward docs # forward docs
self.__doc__ = fn.__doc__ self.__doc__ = fn.__doc__
@property @property
@functools.lru_cache() @functools.lru_cache()
def cache_key(self): def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args. function are args.
""" """
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs self.kwargs = kwargs
self.num_warps = num_warps self.num_warps = num_warps
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
def cdiv(x, y): def cdiv(x, y):
return (x + y - 1) // y return (x + y - 1) // y
def next_power_of_2(n): def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n""" """Return the smallest power of 2 greater than or equal to n"""
n -= 1 n -= 1
@@ -1163,6 +1167,7 @@ def next_power_of_2(n):
###### ######
class TensorWrapper: class TensorWrapper:
def __init__(self, base, dtype): def __init__(self, base, dtype):
self.dtype = dtype self.dtype = dtype

View File

@@ -1,8 +1,8 @@
import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
from functools import wraps from functools import wraps
import triton
from triton._C.libtriton.triton import frontend, ir
# convert block/dtype to ir values # convert block/dtype to ir values
def _to_ir(x, builder): def _to_ir(x, builder):
@@ -111,6 +111,7 @@ class pointer_dtype:
def __str__(self): def __str__(self):
return f'pointer<{self.element_ty}>' return f'pointer<{self.element_ty}>'
# scalar types # scalar types
int1 = dtype(ir.type.get_int1) int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8) int8 = dtype(ir.type.get_int8)
@@ -489,6 +490,7 @@ def broadcast_to(input, shape, _builder=None):
""" """
return frontend.broadcast_to(input, shape, _builder) return frontend.broadcast_to(input, shape, _builder)
@builtin @builtin
def cat(input, other, _builder=None): def cat(input, other, _builder=None):
""" """
@@ -603,6 +605,7 @@ def _add_atomic_docstr(name):
return _decorator return _decorator
@builtin @builtin
@_add_atomic_docstr("compare-and-swap") @_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None): def atomic_cas(pointer, cmp, val, _builder=None):
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
def atomic_xchg(pointer, val, mask=None, _builder=None): def atomic_xchg(pointer, val, mask=None, _builder=None):
return frontend.atomic_xchg(pointer, val, mask, _builder) return frontend.atomic_xchg(pointer, val, mask, _builder)
@builtin @builtin
@_add_atomic_docstr("add") @_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None): def atomic_add(pointer, val, mask=None, _builder=None):
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
def umulhi(x, y, _builder=None): def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder) return frontend.umulhi(x, y, _builder)
def _add_math_1arg_docstr(name): def _add_math_1arg_docstr(name):
def _decorator(func): def _decorator(func):
@@ -697,21 +702,25 @@ def _add_math_1arg_docstr(name):
return _decorator return _decorator
@builtin @builtin
@_add_math_1arg_docstr("exponential") @_add_math_1arg_docstr("exponential")
def exp(x, _builder=None): def exp(x, _builder=None):
return frontend.exp(x, _builder) return frontend.exp(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("natural logarithm") @_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None): def log(x, _builder=None):
return frontend.log(x, _builder) return frontend.log(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("cosine") @_add_math_1arg_docstr("cosine")
def cos(x, _builder=None): def cos(x, _builder=None):
return frontend.cos(x, _builder) return frontend.cos(x, _builder)
@builtin @builtin
@_add_math_1arg_docstr("sine") @_add_math_1arg_docstr("sine")
def sin(x, _builder=None): def sin(x, _builder=None):
@@ -742,6 +751,7 @@ def _add_reduction_docstr(name):
return _decorator return _decorator
@builtin @builtin
@_add_reduction_docstr("maximum") @_add_reduction_docstr("maximum")
def max(input, axis, _builder=None): def max(input, axis, _builder=None):
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
def sum(input, axis, _builder=None): def sum(input, axis, _builder=None):
return frontend.sum(input, axis, _builder) return frontend.sum(input, axis, _builder)
@builtin @builtin
@_add_reduction_docstr("xor sum") @_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None): def xor_sum(input, axis, _builder=None):
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
def abs(x): def abs(x):
return where(x >= 0, x, -x) return where(x >= 0, x, -x)
@triton.jit @triton.jit
def cdiv(x, div): def cdiv(x, div):
""" """
@@ -871,6 +883,7 @@ def ravel(x):
""" """
return triton.language.reshape(x, [x.type.numel]) return triton.language.reshape(x, [x.type.numel])
@triton.jit @triton.jit
def swizzle2d(i, j, size_i, size_j, size_g): def swizzle2d(i, j, size_i, size_j, size_g):
""" """

View File

@@ -1,7 +1,6 @@
import triton import triton
from . import core as tl from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53

View File

@@ -1,4 +1,4 @@
#from .conv import _conv, conv #from .conv import _conv, conv
from .matmul import _matmul, matmul
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .matmul import _matmul, matmul

View File

@@ -1,6 +1,7 @@
import torch
import triton import triton
import triton.language as tl import triton.language as tl
import torch
# ******************************************************** # ********************************************************
# -------------------------------------------------------- # --------------------------------------------------------
@@ -11,6 +12,7 @@ import torch
# -------------------------------------------------------- # --------------------------------------------------------
# ******************************************************** # ********************************************************
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
}) })
@@ -37,17 +39,17 @@ def _sdd_kernel(
start_am = tl.load(lut + 1) start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K) offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + (off_z * stride_za \ a_ptrs = A + (off_z * stride_za
+ off_h * stride_ha \ + off_h * stride_ha
+ offs_am[:, None] * stride_ma \ + offs_am[:, None] * stride_ma
+ offs_ak[None, :] * stride_ak) + offs_ak[None, :] * stride_ak)
# initialize pointers to B # initialize pointers to B
start_bn = tl.load(lut + 2) start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K) offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + (off_z * stride_zb \ b_ptrs = B + (off_z * stride_zb
+ off_h * stride_hb \ + off_h * stride_hb
+ offs_bn[None, :] * stride_nb \ + offs_bn[None, :] * stride_nb
+ offs_bk[:, None] * stride_bk) + offs_bk[:, None] * stride_bk)
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
@@ -69,12 +71,13 @@ def _sdd_kernel(
## ---------------- ## ## ---------------- ##
offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + (off_z * stride_zc \ pc = C + (off_z * stride_zc
+ block_id * stride_hc \ + block_id * stride_hc
+ offs_cm[:, None] * stride_mc \ + offs_cm[:, None] * stride_mc
+ offs_cn[None, :] * stride_nc) + offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True) tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1: if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous() a = a.contiguous()
@@ -119,6 +122,8 @@ def sdd_lut(layout, block, device):
# This operation uses a look-up table that contains pre-computed pointer increments # This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel. # in order to minimize computations in the inner loop of the matmul kernel.
# ----------------------------- # -----------------------------
@triton.jit @triton.jit
def _dsd_kernel( def _dsd_kernel(
A, B, C, A, B, C,
@@ -193,6 +198,7 @@ def _dsd_kernel(
+ offs_cn[None, :] * stride_cn + offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0) tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1: if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous() a = a.contiguous()

View File

@@ -1,7 +1,8 @@
import triton.language as tl
import triton
import torch import torch
import triton
import triton.language as tl
def num_warps(n): def num_warps(n):
if n < 512: if n < 512:
@@ -161,7 +162,7 @@ class _softmax(torch.autograd.Function):
# run kernel # run kernel
M = x.shape[0] M = x.shape[0]
grid = [spdims[0] * spdims[1] * block, M] grid = [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\ _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK=block, BLOCK=block,
APPLY_SCALE=apply_scale, APPLY_SCALE=apply_scale,

View File

@@ -1,7 +1,9 @@
import os import os
import torch
import triton import triton
import triton.language as tl import triton.language as tl
import torch
def next_power_of_2(n): def next_power_of_2(n):

View File

@@ -1,11 +1,14 @@
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
from .matmul_perf_model import * from .matmul_perf_model import *
def init_to_zero(name): def init_to_zero(name):
return lambda nargs: nargs[name].zero_() return lambda nargs: nargs[name].zero_()
def get_configs_io_bound(): def get_configs_io_bound():
configs = [] configs = []
for num_stages in [2, 3, 4, 5, 6]: for num_stages in [2, 3, 4, 5, 6]:
@@ -22,6 +25,7 @@ def get_configs_io_bound():
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs return configs
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
}) })

View File

@@ -1,8 +1,11 @@
import heapq
import torch import torch
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps): def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS ''' ''' return compute throughput in TOPS '''
@@ -11,6 +14,7 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device) tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops return tflops
def estimate_matmul_time( def estimate_matmul_time(
# backend, device, # backend, device,
num_warps, num_stages, num_warps, num_stages,
@@ -73,6 +77,7 @@ def estimate_matmul_time(
f'Activate CTAs: {active_cta_ratio*100}%') f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms return total_time_ms
def prune_num_stages(configs): def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
@@ -104,7 +109,7 @@ def prune_num_stages(configs):
optimal_num_stages = ldgsts_latency / mma_cycles optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages # nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \ nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest: for n in nearest:

View File

@@ -1,10 +1,11 @@
import torch
import os import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
import subprocess import subprocess
import sys import sys
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
try: try:
import triton._C.libtriton.cutlass as _cutlass import triton._C.libtriton.cutlass as _cutlass
@@ -13,6 +14,7 @@ except ImportError:
_cutlass = None _cutlass = None
has_cutlass = False has_cutlass = False
def catch_oor(kernel, pytest_handle=None): def catch_oor(kernel, pytest_handle=None):
try: try:
res = kernel() res = kernel()
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device) c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
# run function # run function
dtype = str(a.dtype).split('.')[-1] dtype = str(a.dtype).split('.')[-1]
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, Ka,\ M, N, Ka,
a.stride(0), a.stride(1),\ a.stride(0), a.stride(1),
b.stride(0), b.stride(1),\ b.stride(0), b.stride(1),
c.stride(0), c.stride(1),\ c.stride(0), c.stride(1),
dtype, dtype, dtype, dtype, dtype, dtype,
a.device.index, torch.cuda.current_stream(a.device).cuda_stream) a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret return ret
def assert_almost_equal(x, y, decimal=2, err_msg=''): def assert_almost_equal(x, y, decimal=2, err_msg=''):
import numpy.testing as npt import numpy.testing as npt
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
@@ -93,6 +96,7 @@ def nvsmi(attrs):
ret = [int(x) for x in ret] ret = [int(x) for x in ret]
return ret return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False): def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
""" """
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -161,6 +165,7 @@ class Benchmark:
""" """
This class is used by the :code:`perf_report` function to generate line plots with a concise API. This class is used by the :code:`perf_report` function to generate line plots with a concise API.
""" """
def __init__( def __init__(
self, self,
x_names, x_names,
@@ -224,9 +229,10 @@ class Mark:
self.benchmarks = benchmarks self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data): def _run(self, bench, save_path, show_plots, print_data):
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import os
y_mean = bench.line_names y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names] y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names] y_max = [f'{x}-max' for x in bench.line_names]
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks) wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper return wrapper
def get_dram_gbps(backend=None, device=None): def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s ''' ''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA # assert backend == CUDA
@@ -309,6 +316,7 @@ def get_dram_gbps(backend=None, device=None):
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps return bw_gbps
def get_max_tensorcore_tflops(backend, device): def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz

View File

@@ -21,8 +21,8 @@
# SOFTWARE. # SOFTWARE.
import argparse import argparse
import subprocess
import re import re
import subprocess
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')

View File

@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel # Compute Kernel
# -------------------------- # --------------------------
from triton.language.core import constexpr
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl

View File

@@ -16,6 +16,8 @@ You will learn about:
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation: # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import triton.language as tl
import triton
import torch import torch
@@ -59,9 +61,6 @@ def naive_softmax(x):
# power-of-two number of elements, so we need to internally "pad" each row and guard the # power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes: # memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit @triton.jit
def softmax_kernel( def softmax_kernel(

View File

@@ -141,6 +141,7 @@ You will specifically learn about:
# #
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -152,6 +153,7 @@ import triton.language as tl
# - An autotuning *key* whose change in values will trigger evaluation of all the # - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs # provided configs
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),

View File

@@ -30,9 +30,11 @@ whose state is generally composed of a bit mask tensor of the same shape as the
import tabulate import tabulate
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@triton.jit @triton.jit
def _dropout( def _dropout(
x_ptr, # pointer to the input x_ptr, # pointer to the input
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output return output
# Input tensor # Input tensor
x = torch.randn(size=(10,)).cuda() x = torch.randn(size=(10,)).cuda()
# Dropout mask # Dropout mask
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
# #
# Let's put it all together. # Let's put it all together.
@triton.jit @triton.jit
def _seeded_dropout( def _seeded_dropout(
x_ptr, x_ptr,

View File

@@ -4,8 +4,10 @@ Layer Normalization
""" """
import torch import torch
import triton.language as tl
import triton import triton
import triton.language as tl
# Forward Pass # Forward Pass
@triton.jit @triton.jit
@@ -97,6 +99,8 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
tl.atomic_xchg(Lock, 0) tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB) # Backward pass (total DW + total DB)
@triton.jit @triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
pid = tl.program_id(0) pid = tl.program_id(0)
@@ -116,6 +120,7 @@ def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N) tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function): class LayerNorm(torch.autograd.Function):
@staticmethod @staticmethod
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1) triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1) triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['N'], x_names=['N'],
@@ -248,4 +254,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
grad_to_none=[x], rep=500) grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms) return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True) bench_layer_norm.run(save_path='.', print_data=True)