diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py new file mode 100644 index 000000000..754c79c79 --- /dev/null +++ b/python/bench/bench_blocksparse.py @@ -0,0 +1,87 @@ +import torch +import triton + +# ------------------------------- +# Matrix Multiplication +# ------------------------------- + +nt = {False: 'n', True: 't'} +square_confs = [ + triton.testing.Benchmark( + x_names = ['M', 'N', 'K'], + x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], + y_name = 'block', + y_vals = [16, 32, 64], + y_lines = ['Block16', 'Block32', 'Block64'], + ylabel = 'TFLOPS', + loglog = False, + plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', + args = {'layout_mode': layout_mode, 'op_mode': op_mode, + 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} + )\ + for AT in [False] for BT in [False] \ + for op_mode in ['sdd', 'dsd', 'dds'] for layout_mode in ['tril', 'dense'] +] + +@triton.testing.perf_report(square_confs) +def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=5, rep=5): + Z, H = 1, 1 + make_layout = { + 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ + 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), + }[layout_mode] + # create layout + shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode] + layout = make_layout(H, shape[0] // block, shape[1] // block) + # create inputs + a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda') + b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda') + # create op + if provider == 'triton': + op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT) + # inputs + a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a + b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b + ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) + num_flops = { + 'sdd': 2 * Z * K * float(layout.sum()) * block * block,\ + 'dsd': 2 * Z * N * float(layout.sum()) * block * block,\ + 'dds': 2 * Z * M * float(layout.sum()) * block * block + }[op_mode]*1e-12 + triton_tflops = num_flops / ms * 1e3 + return triton_tflops + +# ------------------------------- +# Softmax +# ------------------------------- + +square_confs = [ + triton.testing.Benchmark( + x_names = ['M', 'N'], + x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], + y_name = 'block', + y_vals = [16, 32, 64], + y_lines = ['Block16', 'Block32', 'Block64'], + ylabel = 'GBPS', + loglog = False, + plot_name = f'{layout_mode}-square', + args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} + )\ + for layout_mode in ['dense', 'tril'] +] + +@triton.testing.perf_report(square_confs) +def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50): + Z, H = 1, 1 + make_layout = { + 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), + 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), + }[layout_mode] + layout = make_layout(H, M // block, N // block) + a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda') + if provider == 'triton': + a = triton.testing.sparsify_tensor(a, layout, block) + op = triton.ops.blocksparse.softmax(layout, block) + ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep) + gbps = (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) + return gbps \ No newline at end of file
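The two benchmarks above are normally collected and run by the `python setup.py bench` command added to setup.py later in this patch. As a reference, a minimal standalone sketch of driving them directly is shown below; the driver script, the `results/blocksparse` output directory, and the `exist_ok` handling are illustrative assumptions and not part of this patch (a CUDA device plus pandas and matplotlib are needed by `triton.testing.Mark`):

import os
# hypothetical driver placed next to bench_blocksparse.py (not part of this patch)
from bench_blocksparse import bench_matmul, bench_softmax

out_dir = 'results/blocksparse'      # illustrative output location
os.makedirs(out_dir, exist_ok=True)  # Mark._run writes its CSV files into this directory
bench_matmul.run(out_dir, False)     # one CSV per Benchmark config, named after plot_name
bench_softmax.run(out_dir, True)     # True additionally saves PNG plots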
diff --git a/python/bench/bench_cross_entropy.py b/python/bench/bench_cross_entropy.py new file mode 100644 index 000000000..7d98fad9f --- /dev/null +++ b/python/bench/bench_cross_entropy.py @@ -0,0 +1,37 @@ +import torch +import triton + +confs = [ + triton.testing.Benchmark( + x_names = ['N'], + x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], + y_name = 'provider', + y_vals = ['triton', 'torch'], + y_lines = ['Triton', 'Torch'], + ylabel = 'GBPS', + loglog = False, + plot_name = f'{mode}-2048', + args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} + )\ + for mode in ['forward', 'backward'] +] + +@triton.testing.perf_report(confs) +def bench_op(M, N, dtype, mode, provider): + # create inputs + x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True) + idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda') + num_gb = (2 * x.numel() * x.element_size() * 1e-9) + # forward pass + op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \ + 'triton': triton.ops.cross_entropy}[provider] + if mode == 'forward': + ms = triton.testing.do_bench(lambda: op(x, idx)) + if mode == 'backward': + y = op(x, idx) + dy = torch.randn_like(y) + ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True)) + return num_gb / ms * 1e3 + +if __name__ == '__main__': + bench_op.run('tmp', False) \ No newline at end of file diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py new file mode 100644 index 000000000..457ebf4a6 --- /dev/null +++ b/python/bench/bench_matmul.py @@ -0,0 +1,59 @@ +import triton +import torch + +# square benchmarks +nt = {False: 'n', True: 't'} +square_confs = [ + triton.testing.Benchmark( + x_names = ['M', 'N', 'K'], + x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], + y_name = 'provider', + y_vals = ['torch', 'triton', 'cutlass'], + y_lines = ['Torch', 'Triton', 'CUTLASS'], + ylabel = 'TFLOPS', + loglog = False, + plot_name = f'matmul-square-{nt[AT]}{nt[BT]}', + args = {'AT': AT, 'BT': BT, 'dtype': torch.float16} + )\ + for AT in [False, True] for BT in [False, True] +] + +@triton.testing.perf_report(square_confs) +def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5): + import os + a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 + b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 + if AT: a = a.t() + if BT: b = b.t() + num_flops = 2 * M * N * K + if provider == 'torch': + torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) + torch_tflops = num_flops / torch_ms * 1e-9 + return torch_tflops + if provider == 'triton': + triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep) + triton_tflops = num_flops / triton_ms * 1e-9 + return triton_tflops + if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ: + import subprocess + import tempfile + import pandas as pd + # run program specified by CUTLASS_PROFILER env variable + layout_a = 'column' if AT else 'row' + layout_b = 'column' if BT else 'row' + # create temporary file name + fd, fname = tempfile.mkstemp() + # run program and get its output + cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \ + '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \ + f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false'] + # run cmd + subprocess.run(cmd, stdout=subprocess.PIPE) + # read CSV output + df_c = pd.read_csv(f'{fname}.gemm.csv') + cutlass_tflops = max(df_c['GFLOPs']) / 1e3 + return cutlass_tflops + return None + +if __name__ == '__main__': + bench_op.run('tmp', False) diff --git a/python/setup.py b/python/setup.py index 6257a239c..b85276dd7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,31 +14,27 @@ from setuptools.command.test import test as TestCommand import distutils.spawn import torch - def find_llvm(): - versions = ['-10', '-9.0', '-9', '-90', '-8.0', '-8', '-80', ''] + versions = ['-10', '-10.0', ''] supported = ['llvm-config{v}'.format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [p for p in paths if p is not None] if paths: - return paths[0] + return paths[0] config = distutils.spawn.find_executable('llvm-config') - instructions = 'Please install llvm-{8, 9, 10}-dev' + instructions = 'Please install llvm-10-dev' if config is None: raise RuntimeError('Could not find llvm-config. ' + instructions) version = os.popen('{config} --version'.format(config=config)).read() raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions) - class CMakeExtension(Extension): def __init__(self, name, path, sourcedir=''): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) self.path = path - class CMakeBuild(build_ext): - def run(self): try: out = subprocess.check_output(['cmake', '--version']) @@ -63,16 +59,18 @@ class CMakeBuild(build_ext): torch_include_dirs = include_paths(True) torch_library_dirs = library_paths(True) cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) - cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, - '-DBUILD_TUTORIALS=OFF', - '-DBUILD_PYTHON_MODULE=ON', - #'-DPYTHON_EXECUTABLE=' + sys.executable, - #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, - '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)), - '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)), - '-DTORCH_CXX11_ABI=' + cxx11abi, - '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton', - '-DLLVM_CONFIG=' + find_llvm()] + cmake_args = [ + '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, + '-DBUILD_TUTORIALS=OFF', + '-DBUILD_PYTHON_MODULE=ON', + #'-DPYTHON_EXECUTABLE=' + sys.executable, + #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, + '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)), + '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)), + '-DTORCH_CXX11_ABI=' + cxx11abi, + '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton', + '-DLLVM_CONFIG=' + find_llvm() + ] # configuration cfg = 'Debug' if self.debug else 'Release' cfg = 'Release' @@ -90,10 +88,54 @@ class CMakeBuild(build_ext): env = os.environ.copy() if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) + sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env) subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) +class BenchCommand(distutils.cmd.Command): + + description = 'run benchmark suite' + user_options = [ + ('result-dir=', None, 'path to output benchmark results'),\ + ('with-plots', None, 'plot benchmark results'),\ + ('filter=' , None, 'filter benchmarks by name') + ] + + def initialize_options(self): + self.result_dir = 'results' + self.filter = '' + self.with_plots = False + +
def finalize_options(self): + if not os.path.exists(self.result_dir): + os.makedirs(self.result_dir) + + def run(self): + import sys + import inspect + import triton + bench_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bench') + sys.path.append(bench_dir) + for mod in os.listdir(bench_dir): + # skip non python files + if not mod.endswith('.py'): + continue + # skip file not in provided filter + if self.filter and self.filter not in mod: + continue + # skip files that don't start with 'bench_' + if not mod.startswith('bench_'): + continue + print(f'running {mod}...') + mod = __import__(os.path.splitext(mod)[0]) + benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark)) + for name, bench in benchmarks: + result_dir = os.path.join(self.result_dir, mod.__name__.replace('bench_', '')) + if len(benchmarks) > 1: + result_dir = os.path.join(result_dir, name.replace('bench_', '')) + if not os.path.exists(result_dir): + os.makedirs(result_dir) + bench.run(result_dir, self.with_plots) setup( name='triton', @@ -104,21 +146,20 @@ setup( long_description='', packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'], install_requires=['numpy', 'torch'], - package_data={'triton/ops': ['*.c'], - 'triton/ops/blocksparse': ['*.c']}, + package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']}, include_package_data=True, ext_modules=[CMakeExtension('triton', 'triton/_C/')], - cmdclass=dict(build_ext=CMakeBuild), + cmdclass={'build_ext': CMakeBuild, 'bench': BenchCommand}, zip_safe=False, # for PyPI - keyword=['Compiler', 'Deep Learning'], + keywords=['Compiler', 'Deep Learning'], url='https://github.com/ptillet/triton/', download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz', classifiers=[ - 'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package - 'Intended Audience :: Developers', # Define that your audience are developers - 'Topic :: Software Development :: Build Tools', - 'License :: OSI Approved :: MIT License', # Again, pick a license - 'Programming Language :: Python :: 3.6', - ], + 'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package + 'Intended Audience :: Developers', # Define that your audience are developers + 'Topic :: Software Development :: Build Tools', + 'License :: OSI Approved :: MIT License', # Again, pick a license + 'Programming Language :: Python :: 3.6', + ], ) diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py new file mode 100644 index 000000000..0dd793c3b --- /dev/null +++ b/python/test/test_blocksparse.py @@ -0,0 +1,78 @@ +import torch +import triton +import pytest + + +@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK", + [ + (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\ + for at in [False, True]\ + for bt in [False, True]\ + for block in [16, 32, 64] + ] + ) +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384): + # set seed + torch.random.manual_seed(0) + # create inputs + a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda') + b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda') + shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE] + layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + # triton result + op = 
triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) + ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a + rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b + rc = op(ra, rb) + # torch result + ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a + tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b + ta = ta.transpose(2, 3) if TRANS_A else ta + tb = tb.transpose(2, 3) if TRANS_B else tb + tc = torch.matmul(ta, tb) + tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc + tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc + # compare + assert triton.testing.allclose(rc, tc) + +@pytest.mark.parametrize("BLOCK, WIDTH", + [ + (block, width) for block in [32]\ + for width in [256, 576, 1024, 1792] + ] + ) +def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): + # set seed + torch.random.manual_seed(0) + Z, H, M, N = 2, 4, WIDTH, WIDTH + scale = 0.4 + # create inputs + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda') + at_mask = torch.randint(low=0, high=2, size=(N, N), \ + dtype=torch.bool, requires_grad=False, device='cuda') + kp_mask = torch.randint(low=0, high=2, size=(Z, N), \ + dtype=DTYPE, requires_grad=False, device='cuda') + kp_mask[kp_mask == 1.] = float('-inf') + # triton result + op = triton.ops.blocksparse.softmax(layout, BLOCK) + tx = triton.testing.sparsify_tensor(x, layout, BLOCK) + ty = op(tx, + scale=scale, + key_padding_mask=kp_mask, + key_padding_mask_mode='add', + attn_mask=at_mask.to(DTYPE), + attn_mask_mode='mul') + # torch result + rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float('-inf')) + if at_mask is not None: + # broadcast at_mask to the same shape as rx + M = at_mask[None, None, :, :] + torch.zeros_like(rx) + rx[M == 0] = float('-inf') + if kp_mask is not None: + rx += kp_mask[:, None, None, :] + ry = torch.softmax(rx * scale, -1) + ry = torch.softmax(rx * scale, -1) + ry = triton.testing.sparsify_tensor(ry, layout, BLOCK) + # compare + assert triton.testing.allclose(ry, ty) \ No newline at end of file diff --git a/python/tests/test_conv.py b/python/test/test_conv.py similarity index 100% rename from python/tests/test_conv.py rename to python/test/test_conv.py diff --git a/python/tests/test_cross_entropy.py b/python/test/test_cross_entropy.py similarity index 100% rename from python/tests/test_cross_entropy.py rename to python/test/test_cross_entropy.py diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py new file mode 100644 index 000000000..57e45849d --- /dev/null +++ b/python/test/test_matmul.py @@ -0,0 +1,63 @@ +import pytest +import itertools +import triton +import torch + +@pytest.mark.parametrize( + "TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", + itertools.chain(*[ + [ + # 1 warp + (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), + # 2 warp + (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 64, 1, 2, None, None, 
None, AT, BT, DTYPE), + (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), + # 4 warp + (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), + (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE), + # 8 warp + (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE), + # split-k + (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), + # variable input + (128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE) + ] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True] + ])) +def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE): + DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE] + torch.manual_seed(0) + triton.ops._matmul._kernels = dict() + triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)] + if M is None: M = TM + if N is None: N = TN + if K is None: K = TK * TZ + a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5 + b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5 + a = a.t() if AT else a + b = b.t() if BT else b + th_c = torch.matmul(a, b) + tt_c = triton.ops.matmul(a, b) + assert triton.testing.allclose(th_c, tt_c) \ No newline at end of file diff --git a/python/tests/test_blocksparse.py b/python/tests/test_blocksparse.py deleted file mode 100644 index f3686227b..000000000 --- a/python/tests/test_blocksparse.py +++ /dev/null @@ -1,160 +0,0 @@ -import itertools -import torch -import triton as tt -import pytest - -def sparsify_tensor(x, mask, block): - ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) - for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): - ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block] - return ret - -def mask_tensor(x, mask, block, value = 0): - ret = x.clone() - for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): - ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value - return ret - - - -## ----------------------------------------------------------------------------- -## Unit Tests -## ----------------------------------------------------------------------------- - -@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK", - [ - (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\ - for at in [False, True]\ - for bt in [False, True]\ - for block in [16, 32, 64] - ] -) -def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384): - # set seed - torch.random.manual_seed(0) - # create inputs - a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda') - b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda') - shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 
'dds': (b.shape[2], b.shape[3])}[MODE] - layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK)) - # triton result - op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) - ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a - rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b - rc = op(ra, rb) - # torch result - ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a - tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b - ta = ta.transpose(2, 3) if TRANS_A else ta - tb = tb.transpose(2, 3) if TRANS_B else tb - tc = torch.matmul(ta, tb) - tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc - tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc - # compare - rtol, atol = {torch.float32: (1e-4, 1e-5), - torch.float16: (1e-2, 1e-3)}[DTYPE] - assert torch.allclose(rc, tc, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize("BLOCK, WIDTH", - [ - (block, width) for block in [32]\ - for width in [256, 576, 1024, 2048, 4096] - ] -) -def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16): - # set seed - torch.random.manual_seed(0) - Z, H, M, N = 2, 4, WIDTH, WIDTH - scale = 0.4 - # create inputs - layout = torch.randint(2, (H, M//BLOCK, N//BLOCK)) - x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device='cuda') - at_mask = torch.randint(low=0, high=2, size=(N, N), \ - dtype=torch.bool, requires_grad=False, device='cuda') - kp_mask = torch.randint(low=0, high=2, size=(Z, N), \ - dtype=DTYPE, requires_grad=False, device='cuda') - kp_mask[kp_mask==1.] = float('-inf') - # triton result - op = tt.ops.blocksparse.softmax(layout, BLOCK) - tx = sparsify_tensor(x, layout, BLOCK) - ty = op(tx, scale=scale) - # torch result - rx = mask_tensor(x, layout, BLOCK, value=float('-inf')) - # if at_mask is not None: - # # broadcast at_mask to the same shape as rx - # M = at_mask[None, None, :, :] + torch.zeros_like(rx) - # rx[M == 0] = float('-inf') - # if kp_mask is not None: - # rx += kp_mask[:, None, None, :] - ry = torch.softmax(rx*scale, -1) - ry = sparsify_tensor(ry, layout, BLOCK) - # compare - rtol, atol = {torch.float32: (1e-4, 1e-5), - torch.float16: (1e-2, 1e-3)}[DTYPE] - assert torch.allclose(ry , ty, rtol=rtol, atol=atol) - - -## ----------------------------------------------------------------------------- -## Performance Tests -## ----------------------------------------------------------------------------- - -def do_bench(fn, warmup = 10, rep = 50): - import torch as th - start_event = th.cuda.Event(enable_timing=True) - end_event = th.cuda.Event(enable_timing=True) - ret = fn() - for i in range(warmup): - fn() - th.cuda.synchronize() - start_event.record() - for i in range(rep): - fn() - end_event.record() - th.cuda.synchronize() - time_ms = start_event.elapsed_time(end_event) / rep - return time_ms - -def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50): - Z, H = 1, 1 - K = 512 - make_layout = { - 'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), - 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), - }[LAYOUT_MODE] - for N in [128, 256, 512, 1024, 2048, 4096]: - # create layout - M, N, K = N, N, N - shape = {'sdd': (M, N), - 'dsd': (K, M) if TRANS_A else (M, K), - 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE] - layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK) - # create op - op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, 
trans_b=TRANS_B) - # inputs - a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda') - b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda') - a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a - b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b - ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep) - num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12, - 'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12, - 'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE] - triton_tflops = num_flops / ms * 1e3 - -def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50): - Z, H = 1, 1 - K = 512 - make_layout = { - 'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), - 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), - }[LAYOUT_MODE] - for N in [128, 256, 512, 1024, 2048, 4096]: - layout = make_layout(H, N//BLOCK, N//BLOCK) - a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda') - a = sparsify_tensor(a, layout, BLOCK) - op = tt.ops.blocksparse.softmax(layout, BLOCK) - ms = do_bench(lambda: op(a), warmup=warmup, rep=rep) - nbytes = 2 * a.numel() * a.element_size() - triton_gbyps = (nbytes*1e-9) / (ms*1e-3) - print(triton_gbyps) diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py deleted file mode 100644 index bc13a19c6..000000000 --- a/python/tests/test_matmul.py +++ /dev/null @@ -1,151 +0,0 @@ -import pytest -import itertools -import triton as tt -import torch as th - -@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[ - [ - # 1 warp - (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), - (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), - # 2 warp - (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE), - (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), - # 4 warp - (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), - (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE), - # 8 warp - (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE), - (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE), - (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE), - # split-k - (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), - # variable input - (128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE), - (128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE) - ] - for DTYPE in ['float16'] - 
for AT in [False, True] - for BT in [False, True] -])) -def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE): - DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE] - th.manual_seed(0) - tt.ops._matmul._kernels = dict() - tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)] - if M is None: M = TM - if N is None: N = TN - if K is None: K = TK*TZ - a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5 - b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5 - a = a.t() if AT else a - b = b.t() if BT else b - th_c = th.matmul(a, b) - tt_c = tt.ops.matmul(a, b) - rtol, atol = {th.float32: (1e-4, 1e-5), - th.float16: (1e-2, 1e-3)}[DTYPE] - assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol) - - -def do_bench(fn, flops = 0, warmup = 10, rep = 50): - start_event = th.cuda.Event(enable_timing=True) - end_event = th.cuda.Event(enable_timing=True) - ret = fn() - for i in range(warmup): - fn() - th.cuda.synchronize() - start_event.record() - for i in range(rep): - fn() - end_event.record() - th.cuda.synchronize() - time_ms = start_event.elapsed_time(end_event) / rep - return time_ms - -def time_all(fn, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog=True, plot_name='', **kwargs): - import matplotlib.pyplot as plt - import pandas as pd - df = pd.DataFrame(columns = [x_names[0]] + y_lines) - for x in x_vals: - x_args = {x_name: x for x_name in x_names} - row = [fn(**x_args, **{y_name: y}, **kwargs) for y in y_vals] - df.loc[len(df)] = [x] + row - print(df) - if plot_name: - df.plot(x=x_names[0], y=y_lines, ylabel=ylabel, xlabel=' = '.join(x_names), title=f'{plot_name}', loglog=loglog) - plt.savefig(f'{plot_name}.pdf') - -def perf_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50): - import os - a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 - b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 - if AT: a = a.t() - if BT: b = b.t() - num_flops = 2*M*N*K - if provider == 'torch': - torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep) - torch_tflops = num_flops / torch_ms * 1e-9 - return torch_tflops - if provider == 'triton': - triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep) - triton_tflops = num_flops / triton_ms * 1e-9 - return triton_tflops - if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ: - import subprocess - import tempfile - import pandas as pd - # run program specified by CUTLASS_PROFILER env variable - layout_a = 'column' if AT else 'row' - layout_b = 'column' if BT else 'row' - # create temporary file name - fd, fname = tempfile.mkstemp() - # run program and gets its output - cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \ - '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \ - '--profiling-iterations=50', f'--output={fname}', '--verbose=false'] - # run cmd - subprocess.run(cmd, stdout=subprocess.PIPE) - # read CSV output - df_c = pd.read_csv(f'{fname}.gemm.csv') - cutlass_tflops = max(df_c['GFLOPs'])/1e3 - return cutlass_tflops - return None - -if __name__ == '__main__': - # # square - x_square = [128, 256, 512, 1024, 2048, 3072, 4096, 6144] - time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], 
AT = False, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nn') - time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nt') - time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tn') - time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tt') - # tall-skinny - x_tall_skinny = [64, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536] - time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=2048, K=2048, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-2k-2k') - time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=4096, K=4096, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-4k-4k') - time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'], - ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=6144, K=6144, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-6k-6k') \ No newline at end of file diff --git a/python/triton/__init__.py b/python/triton/__init__.py index becddcc32..8f0c1668e 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -2,6 +2,7 @@ # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules +from . import testing from .kernel import * from . 
import ops # C bindings diff --git a/python/triton/testing.py b/python/triton/testing.py new file mode 100644 index 000000000..27812df56 --- /dev/null +++ b/python/triton/testing.py @@ -0,0 +1,78 @@ +import torch + +def sparsify_tensor(x, mask, block): + ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] + return ret + +def mask_tensor(x, mask, block, value=0): + ret = x.clone() + for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): + ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value + return ret + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype] + return torch.allclose(x, y, atol=atol, rtol=rtol) + +def do_bench(fn, flops=0, warmup=10, rep=50): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + ret = fn() + for i in range(warmup): + fn() + torch.cuda.synchronize() + start_event.record() + for i in range(rep): + fn() + end_event.record() + torch.cuda.synchronize() + time_ms = start_event.elapsed_time(end_event) / rep + return time_ms + +class Benchmark: + def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args): + self.x_names = x_names + self.x_vals = x_vals + self.y_name = y_name + self.y_vals = y_vals + self.y_lines = y_lines + self.ylabel = ylabel + self.loglog = loglog + self.plot_name = plot_name + self.args = args + +class Mark: + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench, result_path, with_plot): + import matplotlib.pyplot as plt + import pandas as pd + import os + df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines) + for x in bench.x_vals: + x_args = {x_name: x for x_name in bench.x_names} + row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals] + df.loc[len(df)] = [x] + row + if with_plot and bench.plot_name: + xlabel = ' = '.join(bench.x_names) + plot = df.plot(x=bench.x_names[0], y=bench.y_lines) + plot.set_xlabel(xlabel) + plot.set_ylabel(bench.ylabel) + plot.set_title(bench.plot_name) + plot.set_xscale('log' if bench.loglog else 'linear') + plot.set_yscale('log' if bench.loglog else 'linear') + plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png')) + df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv')) + + def run(self, result_path, with_plot): + for bench in self.benchmarks: + self._run(bench, result_path, with_plot) + +def perf_report(benchmarks): + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper
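For reference, the pieces above compose as follows: `Benchmark` describes one sweep, `perf_report` wraps the benchmark function in a `Mark`, and `Mark.run(result_path, with_plot)` evaluates the function over every (x, y) combination and writes one CSV (and optionally one PNG) per `plot_name`. A minimal sketch of a hypothetical benchmark built on this API is shown below; the `bench_copy` name, the vector-copy workload, and the `results` directory are illustrative assumptions, not part of this patch:

import os
import torch
import triton

# one Benchmark config -> one table/plot, named after plot_name
confs = [
    triton.testing.Benchmark(
        x_names = ['N'],
        x_vals = [2**i for i in range(12, 24, 2)],
        y_name = 'provider',
        y_vals = ['torch'],
        y_lines = ['Torch'],
        ylabel = 'GBPS',
        loglog = False,
        plot_name = 'copy',
        args = {'dtype': torch.float16},
    )
]

@triton.testing.perf_report(confs)
def bench_copy(N, dtype, provider):
    x = torch.randn(N, device='cuda', dtype=dtype)
    ms = triton.testing.do_bench(lambda: x.clone())
    # bytes moved = one read + one write per element; report GB/s
    return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

if __name__ == '__main__':
    os.makedirs('results', exist_ok=True)
    bench_copy.run('results', False)  # writes results/copy.csv

Saved as python/bench/bench_copy.py, such a file would also be picked up automatically by the `python setup.py bench` command introduced in this patch, which collects every `bench_*.py` module in that directory.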