[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg`, and then clean up a few warts. This PR should be a functional no-op: it is all boring whitespace changes, and any config-file changes will land in a separate change to make it easier to review.
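
For reference, the ignored codes are pycodestyle's E501 (line too long), E701 (multiple statements on one line after a colon), and E731 (assignment of a lambda expression); everything else autopep8 can fix is applied mechanically, and isort handles import ordering. The snippet below is a made-up file, not taken from this diff, sketching the conventions the reformatted tree follows:

```
# Hypothetical module, for illustration only; it is not part of this diff.
# It sketches the style the tree ends up with:
#   - isort: imports grouped stdlib / third-party / first-party and alphabetized
#     (the exact isort settings live in the separate config change)
#   - autopep8: two blank lines around top-level definitions, normalized spacing
#   - redundant trailing `\` continuations inside brackets dropped
#   - one-line bodies like `if AT: a = a.t()` split by hand (E701 stays ignored)
#   - long lines and `lambda` assignments survive, since E501/E731 are ignored
import sys

import torch

import triton


def bench(M, N, AT):
    confs = [(m, n)               # bracketed continuation, no `\` needed
             for m in [128, 256]
             for n in [128, 256]]
    a = torch.rand((N, M) if AT else (M, N))
    if AT:                        # previously `if AT: a = a.t()`
        a = a.t()
    tflops = lambda ms: 2. * M * N / ms * 1e-9   # kept as-is: E731 is ignored
    return confs, triton.cdiv(M, 16), tflops


if __name__ == '__main__':
    print(sys.version.split()[0])
    print(bench(128, 256, AT=True)[1])   # -> 8
```

The bulk of the hunks below are exactly these changes: import blocks reordered and regrouped, trailing backslashes dropped, one-line `if` bodies split, and blank lines inserted around top-level definitions and decorators.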
Madeleine Thompson
2022-01-06 14:34:17 -08:00
committed by GitHub
parent 120cda015e
commit 8bf551ae7a
30 changed files with 742 additions and 623 deletions

View File

@@ -1,4 +1,5 @@
import torch
import triton
# -------------------------------
@@ -17,8 +18,8 @@ square_confs = [
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\
for AT in [False] for BT in [False] \
)
for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense']
]
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout
@@ -45,8 +46,8 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -66,7 +67,7 @@ square_confs = [
ylabel='GBPS',
plot_name=f'{layout_mode}-square',
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\
)
for layout_mode in ['dense', 'tril']
]

View File

@@ -1,4 +1,5 @@
import torch
import triton
confs = [
@@ -11,7 +12,7 @@ confs = [
ylabel='GBPS',
plot_name=f'{mode}-2048',
args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\
)
for mode in ['forward', 'backward']
]
@@ -24,7 +25,7 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))

View File

@@ -1,6 +1,6 @@
import triton
import torch
import os
import triton
def rounded_linspace(low, high, steps, div):
@@ -36,8 +36,8 @@ transformer_confs = [
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\
for i, x in enumerate(["N", "K"])\
) for NK in [12288]
for i, x in enumerate(["N", "K"])
for M in [2048]
]
@@ -46,8 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t()
if BT: b = b.t()
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas":
@@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
except:
except Exception:
return None
return None

View File

@@ -1,7 +1,8 @@
import argparse
import sys
import os
import inspect
import os
import sys
import triton

View File

@@ -1,20 +1,19 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils
import glob
import tempfile
import shutil
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn
import urllib.request
import os
import platform
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
def get_llvm():
# tries to find system LLVM
@@ -32,7 +31,7 @@ def get_llvm():
if not os.path.exists(llvm_library_dir):
try:
shutil.rmtree(os.path.join(dir, name))
except:
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')

View File

@@ -1,14 +1,18 @@
from numpy import record
import torch
import triton
import triton.language as tl
import subprocess
import sys
import pytest
import torch
from numpy import record
import triton
#######################
# Utilities
#######################
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
@@ -46,6 +50,8 @@ matmul_data = {
# (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.},
}
@pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K):
ref_gpu_util = matmul_data[(M, N, K)]['v100']
@@ -61,10 +67,11 @@ def test_matmul(M, N, K):
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
#######################
# Element-Wise
#######################
import triton.language as tl
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
@@ -89,6 +96,7 @@ elementwise_data = {
1024 * 65536: {'v100': 0.939},
}
@pytest.mark.parametrize('N', elementwise_data.keys())
def test_elementwise(N):
ref_gpu_util = elementwise_data[N]['v100']
@@ -105,4 +113,3 @@ def test_elementwise(N):
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -206,6 +209,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%']
@@ -298,16 +303,18 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
# real
[
(dtype_x, dtype_y, op, 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
(dtype_x, dtype_y, op, 'real', 'real')
for op in ops
for dtype_x in dtypes
for dtype_y in dtypes
] + \
] +
# NaNs
[('float32', 'float32', op, mode_x, mode_y) \
[('float32', 'float32', op, mode_x, mode_y)
for op in ops
for mode_x, mode_y in [('nan', 'real'),
('real', 'nan'),
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
@@ -558,9 +566,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape) \
for dtype in dtypes\
[(dtype, shape)
for dtype in dtypes
for shape in [128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'):
@@ -608,10 +618,12 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) \
for dtype in ['float32']\
for shape in [(128, 128)]\
[(dtype, shape, perm)
for dtype in ['float32']
for shape in [(128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
# triton kernel
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
@triton.jit
def kernel(out):
@@ -723,10 +737,12 @@ def test_dot_without_load():
# test arange
# ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr):
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
# ---------------
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
M = 32
@@ -788,6 +806,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache):
src = torch.empty(128, device='cuda')
@@ -831,10 +850,13 @@ def test_load_cache_modifier(cache):
# test default
# ---------------
# TODO: can't be local to test_default
@triton.jit
def _impl(value=10):
return value
def test_default():
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -852,6 +874,8 @@ def test_default():
# ---------------
# test noop
# ----------------
def test_noop(device='cuda'):
@triton.jit
def kernel(x):

View File

@@ -1,16 +1,17 @@
import torch
import triton
import triton.language as tl
import numpy as np
import pytest
import scipy.stats
import numpy as np
import torch
from numpy.random import Philox
import triton
import triton.language as tl
#####################################
## Reference Philox Implementation
# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
#####################################
## Unit Tests
# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,8 +136,10 @@ def test_randint(size, seed, device='cuda'):
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@@ -151,8 +157,10 @@ def test_rand(size, seed, device='cuda'):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):

View File

@@ -1,6 +1,7 @@
import torch
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx
if is_causal: at_mask = torch.tril(at_mask)
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :]

View File

@@ -1,12 +1,14 @@
import torch
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']\
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)

View File

@@ -1,8 +1,10 @@
import pytest
import itertools
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",

View File

@@ -1,13 +1,16 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil
import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1
@@ -78,6 +90,7 @@ def test_reuse():
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1

View File

@@ -1,9 +1,11 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix():

View File

@@ -1,26 +1,26 @@
import ast
import builtins
import dbm
import functools
import inspect
import struct
import sys
import textwrap
import hashlib
import inspect
import os
import pickle
import struct
import subprocess
import os
import sys
import tempfile
import textwrap
import time
import warnings
from .tools.disasm import extract
from typing import Dict, Optional
import torch
from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
from filelock import FileLock
import dbm
import tempfile
from typing import Optional, Dict
import time
from .tools.disasm import extract
class CodeGenerator(ast.NodeVisitor):
@@ -135,7 +135,6 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(fn.args[idx])
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
@@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor):
# default: call visit_Assign
return self.visit_Assign(node)
def visit_Assign(self, node):
_names = []
for target in node.targets:
@@ -404,9 +402,9 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
build_cond = lambda: triton.language.where(self.visit(pos_step_node),
self.visit(pos_cond_node),
self.visit(neg_cond_node),
_builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
@@ -632,10 +630,14 @@ class Kernel:
@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
if N % 16 == 0:
return 16
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
def __init__(self, fn):
@@ -675,7 +677,7 @@ class Kernel:
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args)
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
@@ -768,7 +770,6 @@ class Launcher:
return self.kernel(*wargs, **kwargs, grid=self.grid)
class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
@@ -788,6 +789,7 @@ class Autotuner:
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
@@ -814,6 +816,7 @@ class Autotuner:
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
@@ -838,7 +841,7 @@ class Autotuner:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
@@ -876,7 +879,7 @@ def version_key():
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
#########################3
# 3
class DependenciesFinder(ast.NodeVisitor):
@@ -917,11 +920,11 @@ class DependenciesFinder(ast.NodeVisitor):
self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None):
# information of wrapped function
self.fn = fn
@@ -946,7 +949,6 @@ class JITFunction:
# forward docs
self.__doc__ = fn.__doc__
@property
@functools.lru_cache()
def cache_key(self):
@@ -1027,6 +1029,7 @@ class Config:
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
@@ -1150,6 +1153,7 @@ def jit(*args, **kwargs):
def cdiv(x, y):
return (x + y - 1) // y
def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
@@ -1163,6 +1167,7 @@ def next_power_of_2(n):
######
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype

View File

@@ -1,8 +1,8 @@
import triton
from triton._C.libtriton.triton import ir
from triton._C.libtriton.triton import frontend
from functools import wraps
import triton
from triton._C.libtriton.triton import frontend, ir
# convert block/dtype to ir values
def _to_ir(x, builder):
@@ -111,6 +111,7 @@ class pointer_dtype:
def __str__(self):
return f'pointer<{self.element_ty}>'
# scalar types
int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8)
@@ -489,6 +490,7 @@ def broadcast_to(input, shape, _builder=None):
"""
return frontend.broadcast_to(input, shape, _builder)
@builtin
def cat(input, other, _builder=None):
"""
@@ -603,6 +605,7 @@ def _add_atomic_docstr(name):
return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
@@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None):
def atomic_xchg(pointer, val, mask=None, _builder=None):
return frontend.atomic_xchg(pointer, val, mask, _builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None):
@@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None):
def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder)
def _add_math_1arg_docstr(name):
def _decorator(func):
@@ -697,21 +702,25 @@ def _add_math_1arg_docstr(name):
return _decorator
@builtin
@_add_math_1arg_docstr("exponential")
def exp(x, _builder=None):
return frontend.exp(x, _builder)
@builtin
@_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None):
return frontend.log(x, _builder)
@builtin
@_add_math_1arg_docstr("cosine")
def cos(x, _builder=None):
return frontend.cos(x, _builder)
@builtin
@_add_math_1arg_docstr("sine")
def sin(x, _builder=None):
@@ -742,6 +751,7 @@ def _add_reduction_docstr(name):
return _decorator
@builtin
@_add_reduction_docstr("maximum")
def max(input, axis, _builder=None):
@@ -759,6 +769,7 @@ def min(input, axis, _builder=None):
def sum(input, axis, _builder=None):
return frontend.sum(input, axis, _builder)
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None):
@@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None):
def abs(x):
return where(x >= 0, x, -x)
@triton.jit
def cdiv(x, div):
"""
@@ -871,6 +883,7 @@ def ravel(x):
"""
return triton.language.reshape(x, [x.type.numel])
@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
"""

View File

@@ -1,7 +1,6 @@
import triton
from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53

View File

@@ -1,4 +1,4 @@
#from .conv import _conv, conv
from .matmul import _matmul, matmul
from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .matmul import _matmul, matmul

View File

@@ -1,6 +1,7 @@
import torch
import triton
import triton.language as tl
import torch
# ********************************************************
# --------------------------------------------------------
@@ -11,6 +12,7 @@ import torch
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@@ -37,17 +39,17 @@ def _sdd_kernel(
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + (off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
a_ptrs = A + (off_z * stride_za
+ off_h * stride_ha
+ offs_am[:, None] * stride_ma
+ offs_ak[None, :] * stride_ak)
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + (off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
b_ptrs = B + (off_z * stride_zb
+ off_h * stride_hb
+ offs_bn[None, :] * stride_nb
+ offs_bk[:, None] * stride_bk)
## ---------------- ##
## Inner Loop ##
@@ -69,12 +71,13 @@ def _sdd_kernel(
## ---------------- ##
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + (off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
pc = C + (off_z * stride_zc
+ block_id * stride_hc
+ offs_cm[:, None] * stride_mc
+ offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
@@ -119,6 +122,8 @@ def sdd_lut(layout, block, device):
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
@@ -193,6 +198,7 @@ def _dsd_kernel(
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()

View File

@@ -1,7 +1,8 @@
import triton.language as tl
import triton
import torch
import triton
import triton.language as tl
def num_warps(n):
if n < 512:
@@ -161,7 +162,7 @@ class _softmax(torch.autograd.Function):
# run kernel
M = x.shape[0]
grid = [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK=block,
APPLY_SCALE=apply_scale,

View File

@@ -1,7 +1,9 @@
import os
import torch
import triton
import triton.language as tl
import torch
def next_power_of_2(n):

View File

@@ -1,11 +1,14 @@
import torch
import triton.language as tl
import triton
import triton.language as tl
from .matmul_perf_model import *
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
@@ -22,6 +25,7 @@ def get_configs_io_bound():
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
})

View File

@@ -1,8 +1,11 @@
import heapq
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS '''
@@ -11,6 +14,7 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
@@ -73,6 +77,7 @@ def estimate_matmul_time(
f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms
def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
@@ -104,7 +109,7 @@ def prune_num_stages(configs):
optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:

View File

@@ -1,10 +1,11 @@
import torch
import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
import subprocess
import sys
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
try:
import triton._C.libtriton.cutlass as _cutlass
@@ -13,6 +14,7 @@ except ImportError:
_cutlass = None
has_cutlass = False
def catch_oor(kernel, pytest_handle=None):
try:
res = kernel()
@@ -42,11 +44,11 @@ def cutlass_matmul(a, b):
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
# run function
dtype = str(a.dtype).split('.')[-1]
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \
M, N, Ka,\
a.stride(0), a.stride(1),\
b.stride(0), b.stride(1),\
c.stride(0), c.stride(1),\
_cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, Ka,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dtype, dtype, dtype,
a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
@@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret
def assert_almost_equal(x, y, decimal=2, err_msg=''):
import numpy.testing as npt
if isinstance(x, torch.Tensor):
@@ -93,6 +96,7 @@ def nvsmi(attrs):
ret = [int(x) for x in ret]
return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -161,6 +165,7 @@ class Benchmark:
"""
This class is used by the :code:`perf_report` function to generate line plots with a concise API.
"""
def __init__(
self,
x_names,
@@ -224,9 +229,10 @@ class Mark:
self.benchmarks = benchmarks
def _run(self, bench, save_path, show_plots, print_data):
import os
import matplotlib.pyplot as plt
import pandas as pd
import os
y_mean = bench.line_names
y_min = [f'{x}-min' for x in bench.line_names]
y_max = [f'{x}-max' for x in bench.line_names]
@@ -297,6 +303,7 @@ def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper
def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA
@@ -309,6 +316,7 @@ def get_dram_gbps(backend=None, device=None):
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps
def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz

View File

@@ -21,8 +21,8 @@
# SOFTWARE.
import argparse
import subprocess
import re
import subprocess
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')

View File

@@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel
# --------------------------
from triton.language.core import constexpr
import torch
import triton
import triton.language as tl

View File

@@ -16,6 +16,8 @@ You will learn about:
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import triton.language as tl
import triton
import torch
@@ -59,9 +61,6 @@ def naive_softmax(x):
# power-of-two number of elements, so we need to internally "pad" each row and guard the
# memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(

View File

@@ -141,6 +141,7 @@ You will specifically learn about:
#
import torch
import triton
import triton.language as tl
@@ -152,6 +153,7 @@ import triton.language as tl
# - An autotuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),

View File

@@ -30,9 +30,11 @@ whose state is generally composed of a bit mask tensor of the same shape as the
import tabulate
import torch
import triton
import triton.language as tl
@triton.jit
def _dropout(
x_ptr, # pointer to the input
@@ -64,6 +66,7 @@ def dropout(x, x_keep, p):
_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
return output
# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
@@ -97,6 +100,7 @@ print(tabulate.tabulate([
#
# Let's put it all together.
@triton.jit
def _seeded_dropout(
x_ptr,

View File

@@ -4,8 +4,10 @@ Layer Normalization
"""
import torch
import triton.language as tl
import triton
import triton.language as tl
# Forward Pass
@triton.jit
@@ -97,6 +99,8 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
pid = tl.program_id(0)
@@ -116,6 +120,7 @@ def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
@@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
@@ -248,4 +254,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
bench_layer_norm.run(save_path='.', print_data=True)