[STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg`, and then clean up a few warts. This PR should be a no-op: it is all boring whitespace changes, and any config-file changes will come in a separate change to make it easier to review.
committed by GitHub — parent 120cda015e, commit 8bf551ae7a
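For orientation, here is a small, hypothetical before/after sketch of the kinds of rewrites that recur throughout the diff below: import regrouping from isort, plus the autopep8 fixes for redundant line-continuation backslashes inside brackets, one-line `if` statements (E701), and bare `except:` clauses. The module and the names `tidy`, `flag`, and `squares` are made up for illustration and do not come from the Triton sources.

```python
# Before (typical of the old code):
#
#     import triton
#     import torch
#     import os
#     squares = [x * x \
#                for x in range(8) \
#                if x % 2 == 0]
#     if flag: squares.append(0)
#     try:
#         os.stat("missing-file")
#     except:
#         pass
#
# After running isort + autopep8, the same code looks like this:
import os

import torch

import triton

# (torch and triton are only kept to show the isort grouping:
#  standard library, then third-party, then first-party.
#  Neither tool removes unused imports.)


def tidy(flag):
    # The backslashes were redundant: a bracketed expression
    # continues implicitly until the closing bracket.
    squares = [x * x
               for x in range(8)
               if x % 2 == 0]
    # E701: the one-line `if` is split onto two lines.
    if flag:
        squares.append(0)
    # A bare `except:` is narrowed to `except Exception:`.
    try:
        os.stat("missing-file")
    except Exception:
        pass
    return squares
```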
@@ -1,4 +1,5 @@
 import torch
+
 import triton
 
 # -------------------------------
@@ -17,8 +18,8 @@ square_confs = [
 plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
 args={'layout_mode': layout_mode, 'op_mode': op_mode,
 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
-)\
-for AT in [False] for BT in [False] \
+)
+for AT in [False] for BT in [False]
 for op_mode in ['dsd'] for layout_mode in ['dense']
 ]
 
@@ -27,7 +28,7 @@ square_confs = [
 def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
 Z, H = 1, 1
 make_layout = {
-'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
+'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
 }[layout_mode]
 # create layout
@@ -45,8 +46,8 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
 b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
 mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
 num_flops = {
-'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
-'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
+'sdd': 2 * Z * K * float(layout.sum()) * block * block,
+'dsd': 2 * Z * N * float(layout.sum()) * block * block,
 'dds': 2 * Z * M * float(layout.sum()) * block * block
 }[op_mode] * 1e-12
 return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -66,7 +67,7 @@ square_confs = [
 ylabel='GBPS',
 plot_name=f'{layout_mode}-square',
 args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
-)\
+)
 for layout_mode in ['dense', 'tril']
 ]
 
@@ -1,4 +1,5 @@
 import torch
+
 import triton
 
 confs = [
@@ -11,7 +12,7 @@ confs = [
 ylabel='GBPS',
 plot_name=f'{mode}-2048',
 args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
-)\
+)
 for mode in ['forward', 'backward']
 ]
 
@@ -24,7 +25,7 @@ def bench_op(M, N, dtype, mode, provider):
 num_gb = (2 * x.numel() * x.element_size() * 1e-9)
 gbps = lambda ms: num_gb / ms * 1e3
 # forward pass
-op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
+op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
 'triton': triton.ops.cross_entropy}[provider]
 if mode == 'forward':
 mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
@@ -1,6 +1,6 @@
-import triton
 import torch
-import os
+
+import triton
 
 
 def rounded_linspace(low, high, steps, div):
@@ -36,8 +36,8 @@ transformer_confs = [
 ylabel="TFLOPS",
 plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
 args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
-) for NK in [12288]\
-for i, x in enumerate(["N", "K"])\
+) for NK in [12288]
+for i, x in enumerate(["N", "K"])
 for M in [2048]
 ]
 
@@ -46,8 +46,10 @@ transformer_confs = [
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
 a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
 b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
-if AT: a = a.t()
-if BT: b = b.t()
+if AT:
+    a = a.t()
+if BT:
+    b = b.t()
 num_flops = 2 * M * N * K
 tflops = lambda ms: 2. * M * N * K / ms * 1e-9
 if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
 try:
 ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
 return tflops(ms), tflops(max_ms), tflops(min_ms)
-except:
+except Exception:
 return None
 return None
@@ -1,7 +1,8 @@
 import argparse
-import sys
-import os
 import inspect
+import os
+import sys
+
 import triton
 
 
@@ -1,20 +1,19 @@
-import os
-import re
-import sys
-import sysconfig
-import platform
-import subprocess
 import distutils
-import glob
-import tempfile
-import shutil
-from distutils.version import LooseVersion
-from setuptools import setup, Extension, find_packages
-from setuptools.command.build_ext import build_ext
-from setuptools.command.test import test as TestCommand
 import distutils.spawn
-import urllib.request
+import os
+import platform
+import re
+import shutil
+import subprocess
+import sys
 import tarfile
+import tempfile
+import urllib.request
+from distutils.version import LooseVersion
+
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext
+
 
 def get_llvm():
 # tries to find system LLVM
@@ -32,7 +31,7 @@ def get_llvm():
 if not os.path.exists(llvm_library_dir):
 try:
 shutil.rmtree(os.path.join(dir, name))
-except:
+except Exception:
 pass
 url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
 print('downloading and extracting ' + url + '...')
@@ -1,14 +1,18 @@
-from numpy import record
-import torch
-import triton
+import triton.language as tl
 import subprocess
 import sys
 
 import pytest
+import torch
+from numpy import record
+
+import triton
 
 #######################
 # Utilities
 #######################
+
+
 def nvsmi(attrs):
 attrs = ','.join(attrs)
 cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -46,6 +50,8 @@ matmul_data = {
 # (256 , 256 , 8192 ) : {'v100': 0.},
 # (256 , 256 , 32768) : {'v100': 0.},
 }
+
+
 @pytest.mark.parametrize('M, N, K', matmul_data.keys())
 def test_matmul(M, N, K):
 ref_gpu_util = matmul_data[(M, N, K)]['v100']
@@ -61,10 +67,11 @@ def test_matmul(M, N, K):
 cur_gpu_util = cur_gpu_perf / max_gpu_perf
 triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
 
+
 #######################
 # Element-Wise
 #######################
-import triton.language as tl
+
 
 @triton.jit
 def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -89,6 +96,7 @@ elementwise_data = {
 1024 * 65536: {'v100': 0.939},
 }
 
+
 @pytest.mark.parametrize('N', elementwise_data.keys())
 def test_elementwise(N):
 ref_gpu_util = elementwise_data[N]['v100']
@@ -105,4 +113,3 @@ def test_elementwise(N):
 cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
 cur_gpu_util = cur_gpu_perf / max_gpu_perf
 triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
-
@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
 @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
 def test_empty_kernel(dtype_x, device='cuda'):
 SIZE = 128
+
 @triton.jit
 def kernel(X, SIZE: tl.constexpr):
 pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
 def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
 SIZE = 128
 # define the kernel / launch-grid
+
 @triton.jit
 def kernel(Z, X, SIZE: tl.constexpr):
 off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
 def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
 SIZE = 128
 # define the kernel / launch-grid
+
 @triton.jit
 def kernel(Z, X, Y, SIZE: tl.constexpr):
 off = tl.arange(0, SIZE)
@@ -206,6 +209,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 # ---------------
 # test binary ops
 # ---------------
+
+
 @pytest.mark.parametrize("dtype_x, dtype_y, op", [
 (dtype_x, dtype_y, op)
 for op in ['+', '-', '*', '/', '%']
@@ -298,16 +303,18 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
 # test compare ops
 # ---------------
 ops = ['==', '!=', '>', '<', '>=', '<=']
-@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
+
+
+@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
 # real
 [
-(dtype_x, dtype_y, op, 'real', 'real') \
-for op in ops \
-for dtype_x in dtypes \
+(dtype_x, dtype_y, op, 'real', 'real')
+for op in ops
+for dtype_x in dtypes
 for dtype_y in dtypes
-] + \
+] +
 # NaNs
-[('float32', 'float32', op, mode_x, mode_y) \
+[('float32', 'float32', op, mode_x, mode_y)
 for op in ops
 for mode_x, mode_y in [('nan', 'real'),
 ('real', 'nan'),
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
 # 'exp', 'log', 'cos', 'sin'
 # ])
 
+
 @pytest.mark.parametrize("expr", [
 'exp', 'log', 'cos', 'sin'
 ])
@@ -558,9 +566,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
 # ---------------
 # test reduce
 # ---------------
+
+
 @pytest.mark.parametrize("dtype_str, shape",
-[(dtype, shape) \
-for dtype in dtypes\
+[(dtype, shape)
+for dtype in dtypes
 for shape in [128, 512]])
 def test_reduce1d(dtype_str, shape, device='cuda'):
 
@@ -608,10 +618,12 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
 # ---------------
 # test permute
 # ---------------
+
+
 @pytest.mark.parametrize("dtype_str, shape, perm",
-[(dtype, shape, perm) \
-for dtype in ['float32']\
-for shape in [(128, 128)]\
+[(dtype, shape, perm)
+for dtype in ['float32']
+for shape in [(128, 128)]
 for perm in [(1, 0)]])
 def test_permute(dtype_str, shape, perm, device='cuda'):
 
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # test dot
 # ---------------
 
+
 @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
 def test_dot(epilogue, device='cuda'):
 # triton kernel
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
 assert 'ld.global.v4' in ptx
 assert 'st.global.v4' in ptx
 
+
 def test_dot_without_load():
 @triton.jit
 def kernel(out):
@@ -723,10 +737,12 @@ def test_dot_without_load():
 # test arange
 # ---------------
 
+
 @pytest.mark.parametrize("start", [0, 1, 7, 16])
 def test_arange(start, device='cuda'):
 BLOCK = 128
 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+
 @triton.jit
 def _kernel(z, BLOCK: tl.constexpr,
 START: tl.constexpr, END: tl.constexpr):
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
 # ---------------
 # 'bfloat16': torch.bfloat16,
 # Testing masked loads with an intermate copy to shared memory run.
+
+
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
 M = 32
@@ -788,6 +806,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
 reference_out = torch.matmul(in1, in2)
 triton.testing.allclose(out, reference_out)
 
+
 @pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
 def test_load_cache_modifier(cache):
 src = torch.empty(128, device='cuda')
@@ -831,10 +850,13 @@ def test_load_cache_modifier(cache):
 # test default
 # ---------------
 # TODO: can't be local to test_default
+
+
 @triton.jit
 def _impl(value=10):
 return value
 
+
 def test_default():
 value = 5
 ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -852,6 +874,8 @@ def test_default():
 # ---------------
 # test noop
 # ----------------
+
+
 def test_noop(device='cuda'):
 @triton.jit
 def kernel(x):
@@ -1,16 +1,17 @@
-import torch
-import triton
-import triton.language as tl
+import numpy as np
 import pytest
 import scipy.stats
-import numpy as np
+import torch
 
 from numpy.random import Philox
 
+import triton
+import triton.language as tl
 
 #####################################
-## Reference Philox Implementation
+# Reference Philox Implementation
 #####################################
 
 
 class PhiloxConfig:
 def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
 self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
 
+
 #####################################
-## Unit Tests
+# Unit Tests
 #####################################
 
 BLOCK = 1024
 
 # test generation of random uint32
 
+
 @pytest.mark.parametrize('size, seed',
-[(size, seed) for size in ['10', '4,53', '10000']\
+[(size, seed) for size in ['10', '4,53', '10000']
 for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
 )
 def test_randint(size, seed, device='cuda'):
 size = list(map(int, size.split(',')))
+
 @triton.jit
 def kernel(X, N, seed):
 offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,8 +136,10 @@ def test_randint(size, seed, device='cuda'):
 assert out_tri == out_ref
 
 # test uniform PRNG
+
+
 @pytest.mark.parametrize('size, seed',
-[(size, seed) for size in [1000000]\
+[(size, seed) for size in [1000000]
 for seed in [0, 42, 124, 54]]
 )
 def test_rand(size, seed, device='cuda'):
@@ -151,8 +157,10 @@ def test_rand(size, seed, device='cuda'):
 assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
 
 # test normal PRNG
+
+
 @pytest.mark.parametrize('size, seed',
-[(size, seed) for size in [1000000]\
+[(size, seed) for size in [1000000]
 for seed in [0, 42, 124, 54]]
 )
 def test_randn(size, seed, device='cuda'):
@@ -1,6 +1,7 @@
-import torch
-import triton
 import pytest
+import torch
+
+import triton
 
 
 @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
 # torch result
 rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
 # broadcast at_mask to the same shape as rx
-if is_causal: at_mask = torch.tril(at_mask)
+if is_causal:
+    at_mask = torch.tril(at_mask)
 M = at_mask[None, None, :, :] + torch.zeros_like(rx)
 rx[M == 0] = float("-inf")
 # rx += kp_mask[:, None, None, :]
@@ -1,12 +1,14 @@
-import torch
-import triton
 import pytest
+import torch
+
+import triton
 
 
 @pytest.mark.parametrize("M, N, dtype, mode",
 [
 (M, N, dtype, mode) for M in [1024, 821]
 for N in [512, 857, 1871, 2089, 8573, 31000]
-for dtype in ['float16', 'float32']\
+for dtype in ['float16', 'float32']
 for mode in ['forward', 'backward']
 ]
 )
@@ -1,8 +1,10 @@
-import pytest
 import itertools
-import triton
+
+import pytest
 import torch
 
+import triton
 
 
 @pytest.mark.parametrize(
 "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
@@ -1,13 +1,16 @@
-import torch
-import triton
-from triton.code_gen import JITFunction
-import triton.language as tl
 import os
 import shutil
 
 import pytest
+import torch
 
+import triton
+import triton.language as tl
+from triton.code_gen import JITFunction
+
 tmpdir = ".tmp"
+
+
 @triton.jit
 def function_1(i):
 i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
 i = i + 1
 return i
 
+
 @triton.jit
 def kernel(X, i, BLOCK: tl.constexpr):
 i = i + 1
 i = function_1(i)
 tl.store(X, i)
 
+
 @triton.jit(do_not_specialize=["i"])
 def kernel_nospec(X, i, BLOCK: tl.constexpr):
 i = i + 1
 i = function_1(i)
 tl.store(X, i)
 
+
 def apply_src_change(target, old, new):
 delattr(kernel.fn, 'hash')
 delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
 target.src = target.src.replace(new, old)
 return ret
 
+
 def test_nochange():
 baseline = kernel.cache_key
 updated = apply_src_change(kernel, 'i + 1', 'i + 1')
 assert baseline == updated
 
+
 def test_toplevel_change():
 baseline = kernel.cache_key
 updated = apply_src_change(kernel, 'i + 1', 'i + 2')
 assert baseline != updated
 
+
 def test_nested1_change():
 baseline = kernel.cache_key
 updated = apply_src_change(function_1, 'i + 1', 'i + 2')
 assert baseline != updated
 
+
 def reset_tmp_dir():
 os.environ["TRITON_CACHE_DIR"] = tmpdir
 if os.path.exists(tmpdir):
 shutil.rmtree(tmpdir)
 
+
 def test_reuse():
 counter = 0
+
 def inc_counter(key, binary, repr):
 nonlocal counter
 counter += 1
@@ -78,6 +90,7 @@ def test_reuse():
 @pytest.mark.parametrize('mode', ['enable', 'disable'])
 def test_specialize(mode):
 counter = 0
+
 def inc_counter(key, binary, repr):
 nonlocal counter
 counter += 1
@@ -1,9 +1,11 @@
-import torch
-import triton
-import pytest
 import subprocess
-import triton.language as tl
+
 import numpy as np
+import pytest
+import torch
 
+import triton
+import triton.language as tl
 
+
 def get_p2p_matrix():
@@ -1,26 +1,26 @@
 import ast
 import builtins
+import dbm
 import functools
-import inspect
-import struct
-import sys
-import textwrap
 import hashlib
+import inspect
 import os
 import pickle
+import struct
 import subprocess
-import os
+import sys
+import tempfile
+import textwrap
+import time
 import warnings
-from .tools.disasm import extract
+from typing import Dict, Optional
 
 import torch
+from filelock import FileLock
 
 import triton
 import triton._C.libtriton.triton as _triton
-from filelock import FileLock
-import dbm
-import tempfile
-from typing import Optional, Dict
-import time
+from .tools.disasm import extract
 
+
 class CodeGenerator(ast.NodeVisitor):
@@ -135,7 +135,6 @@ class CodeGenerator(ast.NodeVisitor):
 arg_values.append(fn.args[idx])
 idx += 1
 
-
 for arg_name, arg_value in zip(arg_names, arg_values):
 self.set_value(arg_name, arg_value)
 if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
 # default: call visit_Assign
 return self.visit_Assign(node)
 
-
 def visit_Assign(self, node):
 _names = []
 for target in node.targets:
@@ -404,9 +402,9 @@ class CodeGenerator(ast.NodeVisitor):
 pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
 neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
 pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
-build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
-self.visit(pos_cond_node),\
-self.visit(neg_cond_node),\
+build_cond = lambda: triton.language.where(self.visit(pos_step_node),
+self.visit(pos_cond_node),
+self.visit(neg_cond_node),
 _builder=self.builder)
 #cond_node = neg_cond_node
 step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
@@ -632,10 +630,14 @@ class Kernel:
 
 @staticmethod
 def pow2_divisor(N):
-if N % 16 == 0: return 16
-if N % 8 == 0: return 8
-if N % 4 == 0: return 4
-if N % 2 == 0: return 2
+if N % 16 == 0:
+    return 16
+if N % 8 == 0:
+    return 8
+if N % 4 == 0:
+    return 4
+if N % 2 == 0:
+    return 2
 return 1
 
 def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
 tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
 # attributes
 args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
-attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
+attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
 if isinstance(a, int) and i not in self.fn.do_not_specialize}
 
 # transforms ints whose value is one into constants for just-in-time compilation
@@ -768,7 +770,6 @@ class Launcher:
 return self.kernel(*wargs, **kwargs, grid=self.grid)
 
 
-
 class Autotuner:
 def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
 '''
@@ -788,6 +789,7 @@ class Autotuner:
 self.hook = lambda args: 0
 if reset_to_zero is not None:
 self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
 def _hook(args):
 for i in self.reset_idx:
 args[i].zero_()
@@ -814,6 +816,7 @@ class Autotuner:
 )
 # augment meta-parameters with tunable ones
 current = dict(meta, **config.kwargs)
+
 def kernel_call():
 if config.pre_hook:
 config.pre_hook(self.nargs)
@@ -838,7 +841,7 @@ class Autotuner:
 est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
 pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
 bench_start = time.time()
-timings = {config: self._bench(*args, config=config, **kwargs) \
+timings = {config: self._bench(*args, config=config, **kwargs)
 for config in pruned_configs}
 bench_end = time.time()
 self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
 ptxas_version = ''
 return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
 
-#########################3
+# 3
 
 
 class DependenciesFinder(ast.NodeVisitor):
@@ -917,11 +920,11 @@ class DependenciesFinder(ast.NodeVisitor):
 self.ret = (self.ret + func.hash).encode("utf-8")
 self.ret = hashlib.md5(self.ret).hexdigest()
 
 
 class JITFunction:
+
 cache_hook = None
 
-
 def __init__(self, fn, version=None, do_not_specialize=None):
 # information of wrapped function
 self.fn = fn
@@ -946,7 +949,6 @@ class JITFunction:
 # forward docs
 self.__doc__ = fn.__doc__
 
-
 @property
 @functools.lru_cache()
 def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
 :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
 function are args.
 """
+
 def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
 self.kwargs = kwargs
 self.num_warps = num_warps
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
 def cdiv(x, y):
 return (x + y - 1) // y
 
+
 def next_power_of_2(n):
 """Return the smallest power of 2 greater than or equal to n"""
 n -= 1
@@ -1163,6 +1167,7 @@ def next_power_of_2(n):
 
 ######
 
+
 class TensorWrapper:
 def __init__(self, base, dtype):
 self.dtype = dtype
@@ -1,8 +1,8 @@
-import triton
-from triton._C.libtriton.triton import ir
-from triton._C.libtriton.triton import frontend
 from functools import wraps
+
+import triton
+from triton._C.libtriton.triton import frontend, ir
 
 
 # convert block/dtype to ir values
 def _to_ir(x, builder):
@@ -111,6 +111,7 @@ class pointer_dtype:
 def __str__(self):
 return f'pointer<{self.element_ty}>'
 
+
 # scalar types
 int1 = dtype(ir.type.get_int1)
 int8 = dtype(ir.type.get_int8)
@@ -489,6 +490,7 @@ def broadcast_to(input, shape, _builder=None):
 """
 return frontend.broadcast_to(input, shape, _builder)
 
+
 @builtin
 def cat(input, other, _builder=None):
 """
@@ -603,6 +605,7 @@ def _add_atomic_docstr(name):
 
 return _decorator
 
+
 @builtin
 @_add_atomic_docstr("compare-and-swap")
 def atomic_cas(pointer, cmp, val, _builder=None):
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
 def atomic_xchg(pointer, val, mask=None, _builder=None):
 return frontend.atomic_xchg(pointer, val, mask, _builder)
 
+
 @builtin
 @_add_atomic_docstr("add")
 def atomic_add(pointer, val, mask=None, _builder=None):
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
 def umulhi(x, y, _builder=None):
 return frontend.umulhi(x, y, _builder)
 
+
 def _add_math_1arg_docstr(name):
 
 def _decorator(func):
@@ -697,21 +702,25 @@ def _add_math_1arg_docstr(name):
 
 return _decorator
 
+
 @builtin
 @_add_math_1arg_docstr("exponential")
 def exp(x, _builder=None):
 return frontend.exp(x, _builder)
 
+
 @builtin
 @_add_math_1arg_docstr("natural logarithm")
 def log(x, _builder=None):
 return frontend.log(x, _builder)
 
+
 @builtin
 @_add_math_1arg_docstr("cosine")
 def cos(x, _builder=None):
 return frontend.cos(x, _builder)
 
+
 @builtin
 @_add_math_1arg_docstr("sine")
 def sin(x, _builder=None):
@@ -742,6 +751,7 @@ def _add_reduction_docstr(name):
 
 return _decorator
 
+
 @builtin
 @_add_reduction_docstr("maximum")
 def max(input, axis, _builder=None):
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
 def sum(input, axis, _builder=None):
 return frontend.sum(input, axis, _builder)
 
+
 @builtin
 @_add_reduction_docstr("xor sum")
 def xor_sum(input, axis, _builder=None):
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
 def abs(x):
 return where(x >= 0, x, -x)
 
+
 @triton.jit
 def cdiv(x, div):
 """
@@ -871,6 +883,7 @@ def ravel(x):
 """
 return triton.language.reshape(x, [x.type.numel])
 
+
 @triton.jit
 def swizzle2d(i, j, size_i, size_j, size_g):
 """
@@ -1,7 +1,6 @@
 import triton
 from . import core as tl
 
-
 PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
 PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
 PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
@@ -1,4 +1,4 @@
 #from .conv import _conv, conv
-from .matmul import _matmul, matmul
-from .cross_entropy import _cross_entropy, cross_entropy
 from . import blocksparse
+from .cross_entropy import _cross_entropy, cross_entropy
+from .matmul import _matmul, matmul
@@ -1,6 +1,7 @@
+import torch
+
 import triton
 import triton.language as tl
-import torch
 
 # ********************************************************
 # --------------------------------------------------------
@@ -11,6 +12,7 @@ import torch
 # --------------------------------------------------------
 # ********************************************************
 
+
 @triton.heuristics({
 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
 })
@@ -37,17 +39,17 @@ def _sdd_kernel(
 start_am = tl.load(lut + 1)
 offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
 offs_ak = tl.arange(0, TILE_K)
-a_ptrs = A + (off_z * stride_za \
-+ off_h * stride_ha \
-+ offs_am[:, None] * stride_ma \
+a_ptrs = A + (off_z * stride_za
++ off_h * stride_ha
++ offs_am[:, None] * stride_ma
 + offs_ak[None, :] * stride_ak)
 # initialize pointers to B
 start_bn = tl.load(lut + 2)
 offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
 offs_bk = tl.arange(0, TILE_K)
-b_ptrs = B + (off_z * stride_zb \
-+ off_h * stride_hb \
-+ offs_bn[None, :] * stride_nb \
+b_ptrs = B + (off_z * stride_zb
++ off_h * stride_hb
++ offs_bn[None, :] * stride_nb
 + offs_bk[:, None] * stride_bk)
 ## ---------------- ##
 ## Inner Loop ##
@@ -69,12 +71,13 @@ def _sdd_kernel(
 ## ---------------- ##
 offs_cm = tl.arange(0, TILE_M) % BLOCK
 offs_cn = tl.arange(0, TILE_N) % BLOCK
-pc = C + (off_z * stride_zc \
-+ block_id * stride_hc \
-+ offs_cm[:, None] * stride_mc \
+pc = C + (off_z * stride_zc
++ block_id * stride_hc
++ offs_cm[:, None] * stride_mc
 + offs_cn[None, :] * stride_nc)
 tl.store(pc, c, mask=True)
 
+
 def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
 if a.stride(2) != 1 and a.stride(3) != 1:
 a = a.contiguous()
@@ -119,6 +122,8 @@ def sdd_lut(layout, block, device):
 # This operation uses a look-up table that contains pre-computed pointer increments
 # in order to minimize computations in the inner loop of the matmul kernel.
 # -----------------------------
+
+
 @triton.jit
 def _dsd_kernel(
 A, B, C,
@@ -193,6 +198,7 @@ def _dsd_kernel(
 + offs_cn[None, :] * stride_cn
 tl.store(pc, c, mask=offs_cn[None, :] < DS0)
 
+
 def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
 if a.stride(2) != 1 and a.stride(3) != 1:
 a = a.contiguous()
@@ -1,7 +1,8 @@
-import triton.language as tl
-import triton
 import torch
+
+import triton
+import triton.language as tl
 
 
 def num_warps(n):
 if n < 512:
@@ -161,7 +162,7 @@ class _softmax(torch.autograd.Function):
 # run kernel
 M = x.shape[0]
 grid = [spdims[0] * spdims[1] * block, M]
-_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\
+_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
 stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
 BLOCK=block,
 APPLY_SCALE=apply_scale,
@@ -1,7 +1,9 @@
 import os
 
+import torch
+
 import triton
 import triton.language as tl
-import torch
 
+
 def next_power_of_2(n):
@@ -1,11 +1,14 @@
 import torch
-import triton.language as tl
+
 import triton
+import triton.language as tl
 from .matmul_perf_model import *
 
+
 def init_to_zero(name):
 return lambda nargs: nargs[name].zero_()
 
+
 def get_configs_io_bound():
 configs = []
 for num_stages in [2, 3, 4, 5, 6]:
@@ -22,6 +25,7 @@ def get_configs_io_bound():
 num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
 return configs
 
+
 @triton.heuristics({
 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
 })
@@ -1,8 +1,11 @@
+import heapq
+
 import torch
+
 import triton
 import triton._C.libtriton.triton as _triton
 from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
-import heapq
+
 
 def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
 ''' return compute throughput in TOPS '''
@@ -11,6 +14,7 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
 tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
 return tflops
 
+
 def estimate_matmul_time(
 # backend, device,
 num_warps, num_stages,
@@ -73,6 +77,7 @@ def estimate_matmul_time(
 f'Activate CTAs: {active_cta_ratio*100}%')
 return total_time_ms
 
+
 def prune_num_stages(configs):
 backend = _triton.runtime.backend.CUDA
 device = torch.cuda.current_device()
@@ -104,7 +109,7 @@ def prune_num_stages(configs):
 optimal_num_stages = ldgsts_latency / mma_cycles
 
 # nearest stages, prefer large #stages
-nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
+nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
 if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
 
 for n in nearest:
@@ -1,10 +1,11 @@
|
|||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
import triton._C.libtriton.triton as _triton
|
|
||||||
from .code_gen import OutOfResources
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import triton._C.libtriton.triton as _triton
|
||||||
|
from .code_gen import OutOfResources
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton._C.libtriton.cutlass as _cutlass
|
import triton._C.libtriton.cutlass as _cutlass
|
||||||
@@ -13,6 +14,7 @@ except ImportError:
     _cutlass = None
     has_cutlass = False
 
+
 def catch_oor(kernel, pytest_handle=None):
     try:
         res = kernel()
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
     c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
     # run function
     dtype = str(a.dtype).split('.')[-1]
-    _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
-                    M, N, Ka,\
-                    a.stride(0), a.stride(1),\
-                    b.stride(0), b.stride(1),\
-                    c.stride(0), c.stride(1),\
+    _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
+                    M, N, Ka,
+                    a.stride(0), a.stride(1),
+                    b.stride(0), b.stride(1),
+                    c.stride(0), c.stride(1),
                     dtype, dtype, dtype,
                     a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
 
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
             ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
     return ret
 
+
 def assert_almost_equal(x, y, decimal=2, err_msg=''):
     import numpy.testing as npt
     if isinstance(x, torch.Tensor):
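For reference, a hedged usage sketch of `assert_almost_equal` as defined above: it converts `torch.Tensor` inputs before delegating to `numpy.testing`. The tensors here are made-up example values.

```python
import torch

import triton

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.001, 1.999])
triton.testing.assert_almost_equal(a, b, decimal=2)  # passes: elements agree to 2 decimal places
```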
@@ -93,6 +96,7 @@ def nvsmi(attrs):
     ret = [int(x) for x in ret]
     return ret
 
+
 def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
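A hedged usage sketch of `do_bench` with the defaults shown above, run on a CUDA machine: with the default percentiles it returns the median runtime plus the 20th and 80th percentiles, in milliseconds. The matmul being timed is an arbitrary example.

```python
import torch

import triton

a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)

# Returns (median, p20, p80) runtimes in milliseconds for the default percentiles.
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=25, rep=100)
print(f'{ms:.3f} ms (p20={min_ms:.3f}, p80={max_ms:.3f})')
```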
@@ -161,6 +165,7 @@ class Benchmark:
     """
     This class is used by the :code:`perf_report` function to generate line plots with a concise API.
     """
+
     def __init__(
         self,
         x_names,
@@ -224,9 +229,10 @@ class Mark:
         self.benchmarks = benchmarks
 
     def _run(self, bench, save_path, show_plots, print_data):
+        import os
+
         import matplotlib.pyplot as plt
         import pandas as pd
-        import os
         y_mean = bench.line_names
         y_min = [f'{x}-min' for x in bench.line_names]
         y_max = [f'{x}-max' for x in bench.line_names]
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
     wrapper = lambda fn: Mark(fn, benchmarks)
     return wrapper
 
+
 def get_dram_gbps(backend=None, device=None):
     ''' return DRAM bandwidth in GB/s '''
     # assert backend == CUDA
@@ -309,6 +316,7 @@ def get_dram_gbps(backend=None, device=None):
     bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8  # In GB/s
     return bw_gbps
 
+
 def get_max_tensorcore_tflops(backend, device):
     num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
     clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
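The bandwidth formula above is memory clock (kHz) times bus width (bits) times 2 for double data rate, converted to bytes and scaled down by 1024^2. A worked example with assumed, roughly A100-like values:

```python
# Illustrative only: the clock and bus width below are assumed, not read from a device.
mem_clock_khz = 1_215_000       # assumed 1215 MHz memory clock
bus_width = 5120                # assumed 5120-bit HBM bus
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8  # In GB/s
print(bw_gbps)                  # 1483
```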
@@ -21,8 +21,8 @@
 # SOFTWARE.
 
 import argparse
-import subprocess
 import re
+import subprocess
 
 FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
 SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # Compute Kernel
 # --------------------------
 
-from triton.language.core import constexpr
 import torch
+
 import triton
 import triton.language as tl
 
@@ -16,6 +16,8 @@ You will learn about:
 # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
 # Let us consider instead the case of a simple (numerically stabilized) softmax operation:
 
 import torch
 
+import triton
+import triton.language as tl
 
@@ -59,9 +61,6 @@ def naive_softmax(x):
 # power-of-two number of elements, so we need to internally "pad" each row and guard the
 # memory operations properly if we want to handle any possible input shapes:
 
-import triton
-import triton.language as tl
-
 
 @triton.jit
 def softmax_kernel(
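For context, the "numerically stabilized" softmax this tutorial fuses subtracts the per-row maximum before exponentiating so that `exp` never overflows. A minimal eager-PyTorch sketch of that computation (an illustration, not the tutorial's `naive_softmax`):

```python
import torch

def rowwise_softmax(x: torch.Tensor) -> torch.Tensor:
    # Shifting by the row max leaves the result unchanged but keeps exp() in range.
    z = x - x.max(dim=1, keepdim=True).values
    num = torch.exp(z)
    return num / num.sum(dim=1, keepdim=True)

out = rowwise_softmax(torch.randn(4, 8))
print(out.sum(dim=1))  # each row sums to 1
```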
@@ -141,6 +141,7 @@ You will specifically learn about:
 #
 
 import torch
+
 import triton
 import triton.language as tl
 
@@ -152,6 +153,7 @@ import triton.language as tl
 # - An autotuning *key* whose change in values will trigger evaluation of all the
 #   provided configs
 
+
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
@@ -30,9 +30,11 @@ whose state is generally composed of a bit mask tensor of the same shape as the
|
|||||||
|
|
||||||
import tabulate
|
import tabulate
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _dropout(
|
def _dropout(
|
||||||
x_ptr, # pointer to the input
|
x_ptr, # pointer to the input
|
||||||
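As a plain-PyTorch reference for what the `_dropout` kernel above computes element-wise: keep (and rescale by `1 / (1 - p)`) the entries selected by the bit mask, and zero out the rest. The tensors and `p` below are made-up example values.

```python
import torch

p = 0.5
x = torch.randn(10)
x_keep = (torch.rand(10) > p).to(torch.int32)   # the bit-mask state mentioned above
output = torch.where(x_keep == 1, x / (1 - p), torch.zeros_like(x))
print(output)
```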
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
     _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
     return output
 
+
 # Input tensor
 x = torch.randn(size=(10,)).cuda()
 # Dropout mask
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
 #
 # Let's put it all together.
 
+
 @triton.jit
 def _seeded_dropout(
     x_ptr,
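The point of the seeded variant above is that the dropout mask never needs to be stored: it can be regenerated on the fly from an integer seed. A small illustration of that idea in plain PyTorch (the helper below is illustrative, not part of the tutorial):

```python
import torch

def seeded_mask(seed: int, n: int, p: float) -> torch.Tensor:
    # Same seed -> same pseudo-random keep-mask, so nothing has to be saved.
    g = torch.Generator().manual_seed(seed)
    return torch.rand(n, generator=g) > p

assert torch.equal(seeded_mask(123, 10, 0.5), seeded_mask(123, 10, 0.5))
```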
@@ -4,8 +4,10 @@ Layer Normalization
 """
 
 import torch
-import triton.language as tl
+
 import triton
+import triton.language as tl
 
+
 # Forward Pass
 @triton.jit
@@ -97,6 +99,8 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
     tl.atomic_xchg(Lock, 0)
 
 # Backward pass (total DW + total DB)
+
+
 @triton.jit
 def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
     pid = tl.program_id(0)
@@ -116,6 +120,7 @@ def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
     tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
     tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
 
+
 class LayerNorm(torch.autograd.Function):
 
     @staticmethod
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
     triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
     triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
 
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=['N'],
@@ -248,4 +254,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
                                                  grad_to_none=[x], rep=500)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
+
 bench_layer_norm.run(save_path='.', print_data=True)