[PYTHON] Added automated benchmark script (#63)

This adds a bench command to setup.py that runs the benchmark suite and generates a set of CSV files (and, optionally, plots):

python setup.py bench
python setup.py bench --with-plots
python setup.py bench --filter=cross_entropy
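Results are written under --result-dir (default: results/), with one sub-directory per bench_*.py module (plus the function name when a module defines several benchmarks), one CSV per plot and, when --with-plots is passed, a matching PNG. An illustrative layout, with names derived from the benchmark scripts in this commit rather than verbatim output:

results/
    cross_entropy/
        forward-2048.csv
        forward-2048.png
        backward-2048.csv
    matmul/
        matmul-square-nn.csv
        ...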
This commit is contained in:
Philippe Tillet
2021-02-08 12:16:41 -08:00
committed by Philippe Tillet
parent 66c94f21d7
commit 5e3c7f5a60
12 changed files with 472 additions and 339 deletions


@@ -0,0 +1,87 @@
import torch
import triton
# -------------------------------
# Matrix Multiplication
# -------------------------------
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'block',
y_vals = [16, 32, 64],
y_lines = ['Block16', 'Block32', 'Block64'],
ylabel = 'TFLOPS',
loglog = False,
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\
for AT in [False] for BT in [False] \
for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense']
]
@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout
shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
layout = make_layout(H, shape[0] // block, shape[1] // block)
# create inputs
a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
# create op
if provider == 'triton':
op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
# inputs
a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode]*1e-12
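# num_flops is in TFLOP here (note the 1e-12 factor); dividing by the runtime in ms and scaling by 1e3 gives TFLOP/s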
triton_tflops = num_flops / ms * 1e3
return triton_tflops
# -------------------------------
# Softmax
# -------------------------------
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'block',
y_vals = [16, 32, 64],
y_lines = ['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS',
loglog = False,
plot_name = f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\
for layout_mode in ['dense', 'tril']
]
@triton.testing.perf_report(square_confs)
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
layout = make_layout(H, M // block, N // block)
a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
if provider == 'triton':
a = triton.testing.sparsify_tensor(a, layout, block)
op = triton.ops.blocksparse.softmax(layout, block)
ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
gbps = (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
return gbps


@@ -0,0 +1,37 @@
import torch
import triton
confs = [
triton.testing.Benchmark(
x_names = ['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
y_name = 'provider',
y_vals = ['triton', 'torch'],
y_lines = ['Triton', 'Torch'],
ylabel = 'GBPS',
loglog = False,
plot_name = f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\
for mode in ['forward', 'backward']
]
@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':
y = op(x, idx)
dy = torch.randn_like(y)
ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True))
return num_gb / ms * 1e3
if __name__ == '__main__':
bench_op.run('tmp', False)


@@ -0,0 +1,59 @@
import triton
import torch
# square benchmarks
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'provider',
y_vals = ['torch', 'triton', 'cutlass'],
y_lines = ['Torch', 'Triton', 'CUTLASS'],
ylabel = 'TFLOPS',
loglog = False,
plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
args = {'AT': AT, 'BT': BT, 'dtype': torch.float16}
)\
for AT in [False, True] for BT in [False, True]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
import os
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
num_flops = 2 * M * N * K
if provider == 'torch':
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
torch_tflops = num_flops / torch_ms * 1e-9
return torch_tflops
if provider == 'triton':
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
triton_tflops = num_flops / triton_ms * 1e-9
return triton_tflops
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = 'column' if AT else 'row'
layout_b = 'column' if BT else 'row'
# create temporary file name
fd, fname = tempfile.mkstemp()
# run the program and get its output
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
# run cmd
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f'{fname}.gemm.csv')
cutlass_tflops = max(df_c['GFLOPs']) / 1e3
return cutlass_tflops
return None
if __name__ == '__main__':
bench_op.run('.', False)


@@ -14,31 +14,27 @@ from setuptools.command.test import test as TestCommand
import distutils.spawn
import torch
def find_llvm():
versions = ['-10', '-9.0', '-9', '-90', '-8.0', '-8', '-80', '']
versions = ['-10', '-10.0', '']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return paths[0]
return paths[0]
config = distutils.spawn.find_executable('llvm-config')
instructions = 'Please install llvm-{8, 9, 10}-dev'
instructions = 'Please install llvm-10-dev'
if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions)
version = os.popen('{config} --version'.format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
@@ -63,16 +59,18 @@ class CMakeBuild(build_ext):
torch_include_dirs = include_paths(True)
torch_library_dirs = library_paths(True)
cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TUTORIALS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
'-DTORCH_CXX11_ABI=' + cxx11abi,
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
'-DLLVM_CONFIG=' + find_llvm()]
cmake_args = [
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TUTORIALS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
'-DTORCH_CXX11_ABI=' + cxx11abi,
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
'-DLLVM_CONFIG=' + find_llvm()
]
# configuration
cfg = 'Debug' if self.debug else 'Release'
cfg = 'Release'
@@ -90,10 +88,54 @@ class CMakeBuild(build_ext):
env = os.environ.copy()
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
class BenchCommand(distutils.cmd.Command):
description = 'run benchmark suite'
user_options = [
('result-dir=', None, 'path to output benchmark results'),\
('with-plots', None, 'plot benchmark results'),\
('filter=' , None, 'filter benchmarks by name')
]
def initialize_options(self):
self.result_dir = 'results'
self.filter = ''
self.with_plots = False
def finalize_options(self):
if not os.path.exists(self.result_dir):
os.makedirs(self.result_dir)
def run(self):
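# discover bench_*.py modules under the bench/ directory, import them, and run every triton.testing.Mark they define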
import sys
import inspect
import triton
bench_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bench')
sys.path.append(bench_dir)
for mod in os.listdir(bench_dir):
# skip non python files
if not mod.endswith('.py'):
continue
# skip file not in provided filter
if self.filter and self.filter not in mod:
continue
# skip files that don't start with 'bench_'
if not mod.startswith('bench_'):
continue
print(f'running {mod}...')
mod = __import__(os.path.splitext(mod)[0])
benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark))
for name, bench in benchmarks:
result_dir = os.path.join(self.result_dir, mod.__name__.replace('bench_', ''))
if len(benchmarks) > 1:
result_dir = os.path.join(result_dir, name.replace('bench_', ''))
if not os.path.exists(result_dir):
os.makedirs(result_dir)
bench.run(result_dir, self.with_plots)
setup(
name='triton',
@@ -104,21 +146,20 @@ setup(
long_description='',
packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
install_requires=['numpy', 'torch'],
package_data={'triton/ops': ['*.c'],
'triton/ops/blocksparse': ['*.c']},
package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
include_package_data=True,
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass=dict(build_ext=CMakeBuild),
cmdclass={'build_ext': CMakeBuild, 'bench': BenchCommand},
zip_safe=False,
# for PyPI
keyword=['Compiler', 'Deep Learning'],
keywords=['Compiler', 'Deep Learning'],
url='https://github.com/ptillet/triton/',
download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
classifiers=[
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Intended Audience :: Developers', # Define that your audience are developers
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: MIT License', # Again, pick a license
'Programming Language :: Python :: 3.6',
],
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Intended Audience :: Developers', # Define that your audience are developers
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: MIT License', # Again, pick a license
'Programming Language :: Python :: 3.6',
],
)


@@ -0,0 +1,78 @@
import torch
import triton
import pytest
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
for at in [False, True]\
for bt in [False, True]\
for block in [16, 32, 64]
]
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
rc = op(ra, rb)
# torch result
ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
ta = ta.transpose(2, 3) if TRANS_A else ta
tb = tb.transpose(2, 3) if TRANS_B else tb
tc = torch.matmul(ta, tb)
tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
# compare
assert triton.testing.allclose(rc, tc)
@pytest.mark.parametrize("BLOCK, WIDTH",
[
(block, width) for block in [32]\
for width in [256, 576, 1024, 1792]
]
)
def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 4, WIDTH, WIDTH
scale = 0.4
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
at_mask = torch.randint(low=0, high=2, size=(N, N), \
dtype=torch.bool, requires_grad=False, device='cuda')
kp_mask = torch.randint(low=0, high=2, size=(Z, N), \
dtype=DTYPE, requires_grad=False, device='cuda')
kp_mask[kp_mask == 1.] = float('-inf')
# triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK)
tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
ty = op(tx,
scale=scale,
key_padding_mask=kp_mask,
key_padding_mask_mode='add',
attn_mask=at_mask.to(DTYPE),
attn_mask_mode='mul')
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float('-inf'))
if at_mask is not None:
# broadcast at_mask to the same shape as rx
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float('-inf')
if kp_mask is not None:
rx += kp_mask[:, None, None, :]
ry = torch.softmax(rx * scale, -1)
ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
# compare
assert triton.testing.allclose(ry, ty)


@@ -0,0 +1,63 @@
import pytest
import itertools
import triton
import torch
@pytest.mark.parametrize(
"TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
torch.manual_seed(0)
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK * TZ
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
b = b.t() if BT else b
th_c = torch.matmul(a, b)
tt_c = triton.ops.matmul(a, b)
assert triton.testing.allclose(th_c, tt_c)


@@ -1,160 +0,0 @@
import itertools
import torch
import triton as tt
import pytest
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
return ret
def mask_tensor(x, mask, block, value = 0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
for at in [False, True]\
for bt in [False, True]\
for block in [16, 32, 64]
]
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
# triton result
op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
rc = op(ra, rb)
# torch result
ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
ta = ta.transpose(2, 3) if TRANS_A else ta
tb = tb.transpose(2, 3) if TRANS_B else tb
tc = torch.matmul(ta, tb)
tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
# compare
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(rc, tc, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BLOCK, WIDTH",
[
(block, width) for block in [32]\
for width in [256, 576, 1024, 2048, 4096]
]
)
def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 4, WIDTH, WIDTH
scale = 0.4
# create inputs
layout = torch.randint(2, (H, M//BLOCK, N//BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
at_mask = torch.randint(low=0, high=2, size=(N, N), \
dtype=torch.bool, requires_grad=False, device='cuda')
kp_mask = torch.randint(low=0, high=2, size=(Z, N), \
dtype=DTYPE, requires_grad=False, device='cuda')
kp_mask[kp_mask==1.] = float('-inf')
# triton result
op = tt.ops.blocksparse.softmax(layout, BLOCK)
tx = sparsify_tensor(x, layout, BLOCK)
ty = op(tx, scale=scale)
# torch result
rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
# if at_mask is not None:
# # broadcast at_mask to the same shape as rx
# M = at_mask[None, None, :, :] + torch.zeros_like(rx)
# rx[M == 0] = float('-inf')
# if kp_mask is not None:
# rx += kp_mask[:, None, None, :]
ry = torch.softmax(rx*scale, -1)
ry = sparsify_tensor(ry, layout, BLOCK)
# compare
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(ry , ty, rtol=rtol, atol=atol)
## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------
def do_bench(fn, warmup = 10, rep = 50):
import torch as th
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
th.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
th.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
Z, H = 1, 1
K = 512
make_layout = {
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[LAYOUT_MODE]
for N in [128, 256, 512, 1024, 2048, 4096]:
# create layout
M, N, K = N, N, N
shape = {'sdd': (M, N),
'dsd': (K, M) if TRANS_A else (M, K),
'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
# create op
op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
# inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
triton_tflops = num_flops / ms * 1e3
def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
Z, H = 1, 1
K = 512
make_layout = {
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[LAYOUT_MODE]
for N in [128, 256, 512, 1024, 2048, 4096]:
layout = make_layout(H, N//BLOCK, N//BLOCK)
a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
a = sparsify_tensor(a, layout, BLOCK)
op = tt.ops.blocksparse.softmax(layout, BLOCK)
ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
nbytes = 2 * a.numel() * a.element_size()
triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
print(triton_gbyps)


@@ -1,151 +0,0 @@
import pytest
import itertools
import triton as tt
import torch as th
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
]
for DTYPE in ['float16']
for AT in [False, True]
for BT in [False, True]
]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
th.manual_seed(0)
tt.ops._matmul._kernels = dict()
tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK*TZ
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
b = b.t() if BT else b
th_c = th.matmul(a, b)
tt_c = tt.ops.matmul(a, b)
rtol, atol = {th.float32: (1e-4, 1e-5),
th.float16: (1e-2, 1e-3)}[DTYPE]
assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)
def do_bench(fn, flops = 0, warmup = 10, rep = 50):
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
th.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
th.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
def time_all(fn, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog=True, plot_name='', **kwargs):
import matplotlib.pyplot as plt
import pandas as pd
df = pd.DataFrame(columns = [x_names[0]] + y_lines)
for x in x_vals:
x_args = {x_name: x for x_name in x_names}
row = [fn(**x_args, **{y_name: y}, **kwargs) for y in y_vals]
df.loc[len(df)] = [x] + row
print(df)
if plot_name:
df.plot(x=x_names[0], y=y_lines, ylabel=ylabel, xlabel=' = '.join(x_names), title=f'{plot_name}', loglog=loglog)
plt.savefig(f'{plot_name}.pdf')
def perf_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
import os
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
num_flops = 2*M*N*K
if provider == 'torch':
torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
torch_tflops = num_flops / torch_ms * 1e-9
return torch_tflops
if provider == 'triton':
triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
triton_tflops = num_flops / triton_ms * 1e-9
return triton_tflops
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = 'column' if AT else 'row'
layout_b = 'column' if BT else 'row'
# create temporary file name
fd, fname = tempfile.mkstemp()
# run program and gets its output
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \
'--profiling-iterations=50', f'--output={fname}', '--verbose=false']
# run cmd
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f'{fname}.gemm.csv')
cutlass_tflops = max(df_c['GFLOPs'])/1e3
return cutlass_tflops
return None
if __name__ == '__main__':
# # square
x_square = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nn')
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nt')
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tn')
time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tt')
# tall-skinny
x_tall_skinny = [64, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536]
time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=2048, K=2048, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-2k-2k')
time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=4096, K=4096, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-4k-4k')
time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=6144, K=6144, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-6k-6k')


@@ -2,6 +2,7 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from . import testing
from .kernel import *
from . import ops
# C bindings

python/triton/testing.py (new file, 78 lines)

@@ -0,0 +1,78 @@
import torch
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
return ret
def mask_tensor(x, mask, block, value=0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
return torch.allclose(x, y, atol=atol, rtol=rtol)
def do_bench(fn, flops=0, warmup=10, rep=50):
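# time fn with CUDA events: one eager call, `warmup` untimed iterations, then `rep` timed iterations averaged into milliseconds per call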
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
torch.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
torch.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
class Benchmark:
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
self.x_names = x_names
self.x_vals = x_vals
self.y_name = y_name
self.y_vals = y_vals
self.y_lines = y_lines
self.ylabel = ylabel
self.loglog = loglog
self.plot_name = plot_name
self.args = args
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
self.benchmarks = benchmarks
def _run(self, bench, result_path, with_plot):
import matplotlib.pyplot as plt
import pandas as pd
import os
df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names}
row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
df.loc[len(df)] = [x] + row
if with_plot and bench.plot_name:
xlabel = ' = '.join(bench.x_names)
plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
plot.set_xlabel(xlabel)
plot.set_ylabel(bench.ylabel)
plot.set_title(bench.plot_name)
plot.set_xscale('log' if bench.loglog else 'linear')
plot.set_yscale('log' if bench.loglog else 'linear')
plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png'))
df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv'))
def run(self, result_path, with_plot):
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper
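For reference, a minimal sketch of the contract the new bench command expects from a benchmark module, using only the triton.testing API added above. The file name bench_vector_add.py and the plain torch workload are hypothetical, for illustration only; dropped into the bench/ directory, it would be picked up by python setup.py bench --filter=vector_add.

# bench_vector_add.py -- hypothetical example, not part of this commit
import torch
import triton

confs = [
    triton.testing.Benchmark(
        x_names = ['N'],
        x_vals = [2**i for i in range(12, 25, 2)],
        y_name = 'provider',
        y_vals = ['torch'],
        y_lines = ['Torch'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = 'vector-add',
        args = {'dtype': torch.float16}
    )
]

@triton.testing.perf_report(confs)
def bench_add(N, dtype, provider, warmup=10, rep=50):
    x = torch.randn(N, dtype=dtype, device='cuda')
    y = torch.randn(N, dtype=dtype, device='cuda')
    ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)
    # two reads + one write of N elements, reported as GB/s
    return 3 * N * x.element_size() * 1e-9 / (ms * 1e-3)

if __name__ == '__main__':
    bench_add.run('.', False)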