[PYTHON] Added automated benchmark script (#63)
This adds a bench command to setup.py that runs the benchmark suite and writes one CSV file per benchmark (and, optionally, plots):

python setup.py bench
python setup.py bench --with-plots
python setup.py bench --filter=cross_entropy
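Each python/bench/bench_*.py module declares its configurations with triton.testing.Benchmark and exposes functions decorated with triton.testing.perf_report, which the bench command then discovers and runs. As a rough illustration, here is a minimal hypothetical module of that shape (the copy workload, sizes, and file name are invented for the example; only the Benchmark/perf_report/do_bench plumbing mirrors the files added below):

# hypothetical python/bench/bench_example.py
import os
import torch
import triton

confs = [
    triton.testing.Benchmark(
        x_names = ['N'],                      # x axis: problem size
        x_vals = [1024, 2048, 4096],
        y_name = 'provider',                  # one curve per provider
        y_vals = ['torch'],
        y_lines = ['Torch'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = 'copy',                   # output files: copy.csv / copy.png
        args = {'dtype': torch.float16}
    )
]

@triton.testing.perf_report(confs)
def bench_copy(N, dtype, provider):
    x = torch.randn(N, dtype=dtype, device='cuda')
    ms = triton.testing.do_bench(lambda: x.clone())    # average runtime in milliseconds
    num_gb = 2 * x.numel() * x.element_size() * 1e-9   # bytes read + written
    return num_gb / ms * 1e3                           # effective GB/s

if __name__ == '__main__':
    os.makedirs('results', exist_ok=True)
    bench_copy.run('results', False)                   # CSV only; pass True to also save the plot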
committed by Philippe Tillet
parent 66c94f21d7
commit 5e3c7f5a60
python/bench/bench_blocksparse.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import torch
import triton

# -------------------------------
# Matrix Multiplication
# -------------------------------

nt = {False: 'n', True: 't'}
square_confs = [
    triton.testing.Benchmark(
        x_names = ['M', 'N', 'K'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        y_name = 'block',
        y_vals = [16, 32, 64],
        y_lines = ['Block16', 'Block32', 'Block64'],
        ylabel = 'TFLOPS',
        loglog = False,
        plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
        args = {'layout_mode': layout_mode, 'op_mode': op_mode,
                'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
    )\
    for AT in [False] for BT in [False] \
    for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense']
]

@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    # create layout
    shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
    layout = make_layout(H, shape[0] // block, shape[1] // block)
    # create inputs
    a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
    b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
    # create op
    if provider == 'triton':
        op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
        # inputs
        a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
        b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
        ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {
            'sdd': 2 * Z * K * float(layout.sum()) * block * block,\
            'dsd': 2 * Z * N * float(layout.sum()) * block * block,\
            'dds': 2 * Z * M * float(layout.sum()) * block * block
        }[op_mode]*1e-12
        triton_tflops = num_flops / ms * 1e3
        return triton_tflops

# -------------------------------
# Softmax
# -------------------------------

square_confs = [
    triton.testing.Benchmark(
        x_names = ['M', 'N'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        y_name = 'block',
        y_vals = [16, 32, 64],
        y_lines = ['Block16', 'Block32', 'Block64'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = f'{layout_mode}-square',
        args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
    )\
    for layout_mode in ['dense', 'tril']
]

@triton.testing.perf_report(square_confs)
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    layout = make_layout(H, M // block, N // block)
    a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
    if provider == 'triton':
        a = triton.testing.sparsify_tensor(a, layout, block)
        op = triton.ops.blocksparse.softmax(layout, block)
        ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
        gbps = (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
        return gbps
python/bench/bench_cross_entropy.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import torch
import triton

confs = [
    triton.testing.Benchmark(
        x_names = ['N'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
        y_name = 'provider',
        y_vals = ['triton', 'torch'],
        y_lines = ['Triton', 'Torch'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = f'{mode}-2048',
        args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
    )\
    for mode in ['forward', 'backward']
]

@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    num_gb = (2 * x.numel() * x.element_size() * 1e-9)
    # forward pass
    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
          'triton': triton.ops.cross_entropy}[provider]
    if mode == 'forward':
        ms = triton.testing.do_bench(lambda: op(x, idx))
    if mode == 'backward':
        y = op(x, idx)
        dy = torch.randn_like(y)
        ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True))
    return num_gb / ms * 1e3

if __name__ == '__main__':
    bench_op.run('tmp', False)
python/bench/bench_matmul.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import triton
import torch

# square benchmarks
nt = {False: 'n', True: 't'}
square_confs = [
    triton.testing.Benchmark(
        x_names = ['M', 'N', 'K'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        y_name = 'provider',
        y_vals = ['torch', 'triton', 'cutlass'],
        y_lines = ['Torch', 'Triton', 'CUTLASS'],
        ylabel = 'TFLOPS',
        loglog = False,
        plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
        args = {'AT': False, 'BT': False, 'dtype': torch.float16}
    )\
    for AT in [False, True] for BT in [False, True]
]

@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
    import os
    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
    if AT: a = a.t()
    if BT: b = b.t()
    num_flops = 2 * M * N * K
    if provider == 'torch':
        torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
        torch_tflops = num_flops / torch_ms * 1e-9
        return torch_tflops
    if provider == 'triton':
        triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
        triton_tflops = num_flops / triton_ms * 1e-9
        return triton_tflops
    if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
        import subprocess
        import tempfile
        import pandas as pd
        # run the program specified by the CUTLASS_PROFILER env variable
        layout_a = 'column' if AT else 'row'
        layout_b = 'column' if BT else 'row'
        # create temporary file name
        fd, fname = tempfile.mkstemp()
        # run the program and get its output
        cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
               '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
               f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
        # run cmd
        subprocess.run(cmd, stdout=subprocess.PIPE)
        # read CSV output
        df_c = pd.read_csv(f'{fname}.gemm.csv')
        cutlass_tflops = max(df_c['GFLOPs']) / 1e3
        return cutlass_tflops
    return None

if __name__ == '__main__':
    bench_op.run('tmp', False)
python/setup.py
@@ -14,31 +14,27 @@ from setuptools.command.test import test as TestCommand
import distutils.spawn
import torch


def find_llvm():
    versions = ['-10', '-9.0', '-9', '-90', '-8.0', '-8', '-80', '']
    versions = ['-10', '-10.0', '']
    supported = ['llvm-config{v}'.format(v=v) for v in versions]
    paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
    paths = [p for p in paths if p is not None]
    if paths:
        return paths[0]
    config = distutils.spawn.find_executable('llvm-config')
    instructions = 'Please install llvm-{8, 9, 10}-dev'
    instructions = 'Please install llvm-10-dev'
    if config is None:
        raise RuntimeError('Could not find llvm-config. ' + instructions)
    version = os.popen('{config} --version'.format(config=config)).read()
    raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)


class CMakeExtension(Extension):
    def __init__(self, name, path, sourcedir=''):
        Extension.__init__(self, name, sources=[])
        self.sourcedir = os.path.abspath(sourcedir)
        self.path = path


class CMakeBuild(build_ext):

    def run(self):
        try:
            out = subprocess.check_output(['cmake', '--version'])
@@ -63,16 +59,18 @@ class CMakeBuild(build_ext):
        torch_include_dirs = include_paths(True)
        torch_library_dirs = library_paths(True)
        cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
        cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                      '-DBUILD_TUTORIALS=OFF',
                      '-DBUILD_PYTHON_MODULE=ON',
                      #'-DPYTHON_EXECUTABLE=' + sys.executable,
                      #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
                      '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
                      '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
                      '-DTORCH_CXX11_ABI=' + cxx11abi,
                      '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
                      '-DLLVM_CONFIG=' + find_llvm()]
        cmake_args = [
            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
            '-DBUILD_TUTORIALS=OFF',
            '-DBUILD_PYTHON_MODULE=ON',
            #'-DPYTHON_EXECUTABLE=' + sys.executable,
            #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
            '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
            '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
            '-DTORCH_CXX11_ABI=' + cxx11abi,
            '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
            '-DLLVM_CONFIG=' + find_llvm()
        ]
        # configuration
        cfg = 'Debug' if self.debug else 'Release'
        cfg = 'Release'
@@ -90,10 +88,54 @@ class CMakeBuild(build_ext):
        env = os.environ.copy()
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)
        sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
        subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)

class BenchCommand(distutils.cmd.Command):

    description = 'run benchmark suite'
    user_options = [
        ('result-dir=', None, 'path to output benchmark results'),\
        ('with-plots', None, 'plot benchmark results'),\
        ('filter=' , None, 'filter benchmarks by name')
    ]

    def initialize_options(self):
        self.result_dir = 'results'
        self.filter = ''
        self.with_plots = False

    def finalize_options(self):
        if not os.path.exists(self.result_dir):
            os.makedirs(self.result_dir)

    def run(self):
        import sys
        import inspect
        import triton
        bench_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bench')
        sys.path.append(bench_dir)
        for mod in os.listdir(bench_dir):
            # skip non python files
            if not mod.endswith('.py'):
                continue
            # skip files not matching the provided filter
            if self.filter and self.filter not in mod:
                continue
            # skip files that don't start with 'bench_'
            if not mod.startswith('bench_'):
                continue
            print(f'running {mod}...')
            mod = __import__(os.path.splitext(mod)[0])
            benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark))
            for name, bench in benchmarks:
                result_dir = os.path.join(self.result_dir, mod.__name__.replace('bench_', ''))
                if len(benchmarks) > 1:
                    result_dir = os.path.join(result_dir, name.replace('bench_', ''))
                if not os.path.exists(result_dir):
                    os.makedirs(result_dir)
                bench.run(result_dir, self.with_plots)

setup(
    name='triton',
@@ -104,21 +146,20 @@ setup(
    long_description='',
    packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
    install_requires=['numpy', 'torch'],
    package_data={'triton/ops': ['*.c'],
                  'triton/ops/blocksparse': ['*.c']},
    package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
    include_package_data=True,
    ext_modules=[CMakeExtension('triton', 'triton/_C/')],
    cmdclass=dict(build_ext=CMakeBuild),
    cmdclass={'build_ext': CMakeBuild, 'bench': BenchCommand},
    zip_safe=False,
    # for PyPI
    keyword=['Compiler', 'Deep Learning'],
    keywords=['Compiler', 'Deep Learning'],
    url='https://github.com/ptillet/triton/',
    download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
    classifiers=[
        'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
        'Intended Audience :: Developers', # Define that your audience are developers
        'Topic :: Software Development :: Build Tools',
        'License :: OSI Approved :: MIT License', # Again, pick a license
        'Programming Language :: Python :: 3.6',
    ],
)
python/test/test_blocksparse.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import torch
import triton
import pytest


@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
                         [
                             (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
                             for at in [False, True]\
                             for bt in [False, True]\
                             for block in [16, 32, 64]
                         ]
                         )
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
    shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # triton result
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    rc = op(ra, rb)
    # torch result
    ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    # compare
    assert triton.testing.allclose(rc, tc)

@pytest.mark.parametrize("BLOCK, WIDTH",
                         [
                             (block, width) for block in [32]\
                             for width in [256, 576, 1024, 1792]
                         ]
                         )
def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 4, WIDTH, WIDTH
    scale = 0.4
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
    at_mask = torch.randint(low=0, high=2, size=(N, N), \
                            dtype=torch.bool, requires_grad=False, device='cuda')
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), \
                            dtype=DTYPE, requires_grad=False, device='cuda')
    kp_mask[kp_mask == 1.] = float('-inf')
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
    tx = triton.testing.sparsify_tensor(x, layout, BLOCK)
    ty = op(tx,
            scale=scale,
            key_padding_mask=kp_mask,
            key_padding_mask_mode='add',
            attn_mask=at_mask.to(DTYPE),
            attn_mask_mode='mul')
    # torch result
    rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float('-inf'))
    if at_mask is not None:
        # broadcast at_mask to the same shape as rx
        M = at_mask[None, None, :, :] + torch.zeros_like(rx)
        rx[M == 0] = float('-inf')
    if kp_mask is not None:
        rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx * scale, -1)
    ry = torch.softmax(rx * scale, -1)
    ry = triton.testing.sparsify_tensor(ry, layout, BLOCK)
    # compare
    assert triton.testing.allclose(ry, ty)
python/test/test_matmul.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import pytest
import itertools
import triton
import torch

@pytest.mark.parametrize(
    "TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
    itertools.chain(*[
        [
            # 1 warp
            (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
            (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
            (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            # 2 warp
            (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            # 4 warp
            (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
            (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
            # 8 warp
            (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
            (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
            (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
            # split-k
            (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
            (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
            (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
            # variable input
            (128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
        ] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
    ]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
    torch.manual_seed(0)
    triton.ops._matmul._kernels = dict()
    triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
    if M is None: M = TM
    if N is None: N = TN
    if K is None: K = TK * TZ
    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
    a = a.t() if AT else a
    b = b.t() if BT else b
    th_c = torch.matmul(a, b)
    tt_c = triton.ops.matmul(a, b)
    assert triton.testing.allclose(th_c, tt_c)
@@ -1,160 +0,0 @@
import itertools
import torch
import triton as tt
import pytest

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
    return ret

def mask_tensor(x, mask, block, value = 0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
    return ret



## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------

@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
                         [
                             (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
                             for at in [False, True]\
                             for bt in [False, True]\
                             for block in [16, 32, 64]
                         ]
                         )
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
    shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
    layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
    # triton result
    op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    rc = op(ra, rb)
    # torch result
    ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    # compare
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(rc, tc, rtol=rtol, atol=atol)


@pytest.mark.parametrize("BLOCK, WIDTH",
                         [
                             (block, width) for block in [32]\
                             for width in [256, 576, 1024, 2048, 4096]
                         ]
                         )
def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 4, WIDTH, WIDTH
    scale = 0.4
    # create inputs
    layout = torch.randint(2, (H, M//BLOCK, N//BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda')
    at_mask = torch.randint(low=0, high=2, size=(N, N), \
                            dtype=torch.bool, requires_grad=False, device='cuda')
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), \
                            dtype=DTYPE, requires_grad=False, device='cuda')
    kp_mask[kp_mask==1.] = float('-inf')
    # triton result
    op = tt.ops.blocksparse.softmax(layout, BLOCK)
    tx = sparsify_tensor(x, layout, BLOCK)
    ty = op(tx, scale=scale)
    # torch result
    rx = mask_tensor(x, layout, BLOCK, value=float('-inf'))
    # if at_mask is not None:
    #     # broadcast at_mask to the same shape as rx
    #     M = at_mask[None, None, :, :] + torch.zeros_like(rx)
    #     rx[M == 0] = float('-inf')
    # if kp_mask is not None:
    #     rx += kp_mask[:, None, None, :]
    ry = torch.softmax(rx*scale, -1)
    ry = sparsify_tensor(ry, layout, BLOCK)
    # compare
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(ry , ty, rtol=rtol, atol=atol)


## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------

def do_bench(fn, warmup = 10, rep = 50):
    import torch as th
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms

def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        # create layout
        M, N, K = N, N, N
        shape = {'sdd': (M, N),
                 'dsd': (K, M) if TRANS_A else (M, K),
                 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
        layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
        # create op
        op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
        # inputs
        a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
        b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
        b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
        ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
        triton_tflops = num_flops / ms * 1e3

def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        layout = make_layout(H, N//BLOCK, N//BLOCK)
        a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK)
        op = tt.ops.blocksparse.softmax(layout, BLOCK)
        ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
        nbytes = 2 * a.numel() * a.element_size()
        triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
        print(triton_gbyps)
@@ -1,151 +0,0 @@
import pytest
import itertools
import triton as tt
import torch as th

@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
    [
        # 1 warp
        (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
        (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
        (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
        # 2 warp
        (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
        (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
        # 4 warp
        (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
        (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
        # 8 warp
        (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
        # split-k
        (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
        (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
        (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
        # variable input
        (128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
        (128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
        (128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
        (128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
    ]
    for DTYPE in ['float16']
    for AT in [False, True]
    for BT in [False, True]
]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
    th.manual_seed(0)
    tt.ops._matmul._kernels = dict()
    tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
    if M is None: M = TM
    if N is None: N = TN
    if K is None: K = TK*TZ
    a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
    b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
    a = a.t() if AT else a
    b = b.t() if BT else b
    th_c = th.matmul(a, b)
    tt_c = tt.ops.matmul(a, b)
    rtol, atol = {th.float32: (1e-4, 1e-5),
                  th.float16: (1e-2, 1e-3)}[DTYPE]
    assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)


def do_bench(fn, flops = 0, warmup = 10, rep = 50):
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms

def time_all(fn, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog=True, plot_name='', **kwargs):
    import matplotlib.pyplot as plt
    import pandas as pd
    df = pd.DataFrame(columns = [x_names[0]] + y_lines)
    for x in x_vals:
        x_args = {x_name: x for x_name in x_names}
        row = [fn(**x_args, **{y_name: y}, **kwargs) for y in y_vals]
        df.loc[len(df)] = [x] + row
    print(df)
    if plot_name:
        df.plot(x=x_names[0], y=y_lines, ylabel=ylabel, xlabel=' = '.join(x_names), title=f'{plot_name}', loglog=loglog)
        plt.savefig(f'{plot_name}.pdf')

def perf_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
    import os
    a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
    b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
    if AT: a = a.t()
    if BT: b = b.t()
    num_flops = 2*M*N*K
    if provider == 'torch':
        torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
        torch_tflops = num_flops / torch_ms * 1e-9
        return torch_tflops
    if provider == 'triton':
        triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
        triton_tflops = num_flops / triton_ms * 1e-9
        return triton_tflops
    if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
        import subprocess
        import tempfile
        import pandas as pd
        # run program specified by CUTLASS_PROFILER env variable
        layout_a = 'column' if AT else 'row'
        layout_b = 'column' if BT else 'row'
        # create temporary file name
        fd, fname = tempfile.mkstemp()
        # run program and gets its output
        cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
               '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \
               '--profiling-iterations=50', f'--output={fname}', '--verbose=false']
        # run cmd
        subprocess.run(cmd, stdout=subprocess.PIPE)
        # read CSV output
        df_c = pd.read_csv(f'{fname}.gemm.csv')
        cutlass_tflops = max(df_c['GFLOPs'])/1e3
        return cutlass_tflops
    return None

if __name__ == '__main__':
    # # square
    x_square = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
    time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nn')
    time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nt')
    time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tn')
    time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tt')
    # tall-skinny
    x_tall_skinny = [64, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536]
    time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=2048, K=2048, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-2k-2k')
    time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=4096, K=4096, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-4k-4k')
    time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'],
             ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=6144, K=6144, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-6k-6k')
python/triton/__init__.py
@@ -2,6 +2,7 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from . import testing
from .kernel import *
from . import ops
# C bindings
python/triton/testing.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import torch

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret

def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret

def allclose(x, y):
    assert x.dtype == y.dtype
    rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
    return torch.allclose(x, y, atol=atol, rtol=rtol)

def do_bench(fn, flops=0, warmup=10, rep=50):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    torch.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    torch.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms

class Benchmark:
    def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
        self.x_names = x_names
        self.x_vals = x_vals
        self.y_name = y_name
        self.y_vals = y_vals
        self.y_lines = y_lines
        self.ylabel = ylabel
        self.loglog = loglog
        self.plot_name = plot_name
        self.args = args

class Mark:
    def __init__(self, fn, benchmarks):
        self.fn = fn
        self.benchmarks = benchmarks

    def _run(self, bench, result_path, with_plot):
        import matplotlib.pyplot as plt
        import pandas as pd
        import os
        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
            df.loc[len(df)] = [x] + row
        if with_plot and bench.plot_name:
            xlabel = ' = '.join(bench.x_names)
            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
            plot.set_xlabel(xlabel)
            plot.set_ylabel(bench.ylabel)
            plot.set_title(bench.plot_name)
            plot.set_xscale('log' if bench.loglog else 'linear')
            plot.set_yscale('log' if bench.loglog else 'linear')
            plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png'))
        df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv'))

    def run(self, result_path, with_plot):
        for bench in self.benchmarks:
            self._run(bench, result_path, with_plot)

def perf_report(benchmarks):
    wrapper = lambda fn: Mark(fn, benchmarks)
    return wrapper