[STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg`, then clean up a few warts. This PR should be a no-op; the idea is that these are all boring whitespace and import-order changes, and any config-file changes will land in a separate change to make review easier.
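For a concrete sense of what the isort pass does, here is a representative before/after of one import block, lightly adapted from the hunks below (whether `triton` gets its own first-party section depends on the `.isort.cfg`, which lands in the separate config change):
```
# Before: ad-hoc import order, as found in several of the test files.
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil

# After isort: standard library first, then third-party, then triton itself,
# each group alphabetized and separated by a blank line.
import os
import shutil

import torch

import triton
import triton.language as tl
from triton.code_gen import JITFunction
```
The ignored autopep8 codes are E501 (line too long), E701 (multiple statements on one line), and E731 (lambda assignment), so those patterns are left for manual cleanup rather than auto-rewritten.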
committed by GitHub
parent 120cda015e · commit 8bf551ae7a
The rest of the page is the rendered diff for this commit (file names, line numbers, and the +/− gutter were lost in extraction, so only the affected source lines survive, old and new interleaved). The recoverable hunks show the same mechanical cleanups repeated across the `./python` tree: the benchmark scripts, the unit/regression/op tests, the build script, the code generator and autotuner, the language built-ins, the blocksparse, cross-entropy, and matmul ops, the testing utilities, and the tutorials.

- isort regroups imports into standard-library, third-party, and `triton`/local sections, alphabetized within each group and separated by blank lines; a handful of unused imports disappear along the way.
- Redundant trailing `\` continuations are dropped where the expression already continues inside parentheses or brackets, e.g. `for AT in [False] for BT in [False] \` becomes `for AT in [False] for BT in [False]`, and multi-line argument lists such as the `_cutlass.matmul(...)` and `triton.language.where(...)` calls lose their backslashes.
- One-line compound statements are split, e.g. `if AT: a = a.t()` becomes `if AT:` followed by `a = a.t()`, and `Kernel.pow2_divisor` gets one `return` per branch.
- Bare `except:` clauses become `except Exception:`.
- PEP 8 blank lines are inserted before top-level and decorated definitions, and `## ...` comment banners are normalized to `# ...`.

Consistent with the description above, none of the hunks are functional changes; the diff is import order, whitespace, and lint fixes only.