2021-01-29 17:27:16 -05:00
|
|
|
import pytest
|
|
|
|
import itertools
|
|
|
|
import triton as tt
|
|
|
|
import torch as th
|
|
|
|
|
|
|
|
# Tile / warp / problem-size combinations exercised by `test_op`, laid out as
# (TM, TN, TK, NWARP, M, N, K). A problem size of None means "use the tile
# size for that dimension" (resolved inside the test body).
_MATMUL_SHAPES = [
    # 1 warp
    (16, 16, 16, 1, None, None, None),
    (32, 16, 16, 1, None, None, None),
    (16, 32, 16, 1, None, None, None),
    (16, 16, 32, 1, None, None, None),
    (32, 16, 32, 1, None, None, None),
    (16, 32, 32, 1, None, None, None),
    (16, 16, 64, 1, None, None, None),
    (64, 16, 64, 1, None, None, None),
    (16, 64, 64, 1, None, None, None),
    # 2 warp
    (64, 32, 64, 2, None, None, None),
    (32, 64, 64, 2, None, None, None),
    (64, 32, 16, 2, None, None, None),
    (32, 64, 16, 2, None, None, None),
    (128, 32, 32, 2, None, None, None),
    (32, 128, 32, 2, None, None, None),
    # 4 warp
    (128, 64, 16, 4, None, None, None),
    (64, 128, 16, 4, None, None, None),
    (128, 32, 32, 4, None, None, None),
    (32, 128, 32, 4, None, None, None),
    (128, 32, 64, 4, None, None, None),
    (32, 128, 64, 4, None, None, None),
    # 8 warp
    (128, 256, 16, 8, None, None, None),
    (256, 128, 16, 8, None, None, None),
    (256, 128, 32, 8, None, None, None),
    # variable input
    (128, 128, 32, 4, 256, 256, 256),
    (128, 128, 32, 4, 384, 128, 640),
    (128, 128, 32, 4, 107, 233, 256),
    (128, 128, 32, 4, 107, 233, 311),
]


@pytest.mark.parametrize(
    "TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE",
    [
        (*shape, AT, BT, DTYPE)
        # product() varies BT fastest, then AT, then DTYPE; every shape is
        # emitted for one (DTYPE, AT, BT) combination before moving on.
        for DTYPE, AT, BT in itertools.product(['float16'], [False, True], [False, True])
        for shape in _MATMUL_SHAPES
    ],
)
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
    """Compare `tt.ops.matmul` against `th.matmul` for one configuration.

    TM, TN, TK and NWARP are installed as the only entries of the
    corresponding `tt.ops._matmul` attributes. M, N, K give the problem size
    and fall back to the tile sizes when None. AT / BT select whether the A /
    B operand is created with swapped dimensions and passed as a transposed
    view. DTYPE is the dtype name ('float16' or 'float32').
    """
    dtype = {'float16': th.float16, 'float32': th.float32}[DTYPE]
    th.manual_seed(0)

    # Reset the kernel table and pin the tile/warp lists to this single
    # configuration (presumably so only this variant gets built -- confirm
    # against tt.ops._matmul).
    matmul_op = tt.ops._matmul
    matmul_op.kernel = dict()
    matmul_op.TM = [TM]
    matmul_op.TN = [TN]
    matmul_op.TK = [TK]
    matmul_op.num_warps = [NWARP]

    # Problem dimensions default to the tile dimensions.
    M = TM if M is None else M
    N = TN if N is None else N
    K = TK if K is None else K

    # Inputs are scaled by 1/sqrt(K). A transposed operand is allocated with
    # swapped dimensions and then viewed through .t().
    a_shape = (K, M) if AT else (M, K)
    b_shape = (N, K) if BT else (K, N)
    a = th.randn(a_shape, device='cuda', dtype=dtype) / K ** 0.5
    b = th.randn(b_shape, device='cuda', dtype=dtype) / K ** 0.5
    if AT:
        a = a.t()
    if BT:
        b = b.t()

    expected = th.matmul(a, b)
    actual = tt.ops.matmul(a, b)

    # (rtol, atol) per dtype.
    tolerances = {
        th.float32: (1e-4, 1e-5),
        th.float16: (1e-2, 1e-3),
    }
    rtol, atol = tolerances[dtype]
    assert th.allclose(actual, expected, atol=atol, rtol=rtol)
|
|
|
|
|
|
|
|
|
|
|
|
def do_bench(fn, flops=0, warmup=10, rep=50):
    """Return the average time of one `fn()` call, measured with CUDA events.

    `fn` is called once, then `warmup` more times, before timing starts. The
    timed section calls it `rep` times between two recorded events, and the
    elapsed time between those events is divided by `rep`.

    `flops` is accepted but not used by this function.
    """
    tic = th.cuda.Event(enable_timing=True)
    toc = th.cuda.Event(enable_timing=True)

    # Untimed calls: one initial invocation followed by the warmup runs.
    fn()
    for _ in range(warmup):
        fn()

    # Synchronize before starting the measurement so the untimed calls have
    # completed on the device.
    th.cuda.synchronize()
    tic.record()
    for _ in range(rep):
        fn()
    toc.record()
    # Synchronize again so the end event has been reached before it is read.
    th.cuda.synchronize()

    return tic.elapsed_time(toc) / rep
|
2021-01-29 17:27:16 -05:00
|
|
|
|
|
|
|
|
|
|
|
def _cutlass_tflops(profiler, M, N, K, AT, BT, warmup, rep):
    """Run the CUTLASS profiler on one GEMM problem and return its best TFLOPS.

    `profiler` is the path of the profiler executable. M, N, K give the
    problem size; AT / BT select a column-major layout for A / B (row-major
    otherwise). `warmup` and `rep` are forwarded as the profiler's warmup and
    profiling iteration counts.

    The result is the maximum of the report's 'GFLOPs' column divided by 1e3.
    """
    import os
    import subprocess
    import tempfile

    import pandas as pd

    layout_a = 'column' if AT else 'row'
    layout_b = 'column' if BT else 'row'
    # The profiler is given an output prefix and the report is read back from
    # `<prefix>.gemm.csv`. Keeping the prefix inside a temporary directory
    # guarantees the report is removed afterwards and no descriptor is left
    # open.
    with tempfile.TemporaryDirectory() as tmpdir:
        prefix = os.path.join(tmpdir, 'cutlass')
        # NOTE: operand types are fixed to f16 here, independent of the dtype
        # benchmarked by the caller.
        cmd = [
            profiler,
            f'--m={M}', f'--n={N}', f'--k={K}',
            f'--A=f16:{layout_a}', f'--B=f16:{layout_b}',
            '--C=f16:column', '--accum=f32', '--operation=gemm',
            '--verification-enabled=false',
            f'--warmup-iterations={warmup}',
            f'--profiling-iterations={rep}',
            f'--output={prefix}', '--verbose=false',
        ]
        # stdout is captured only to keep the profiler's output off the console.
        subprocess.run(cmd, stdout=subprocess.PIPE)
        report = pd.read_csv(f'{prefix}.gemm.csv')
    return max(report['GFLOPs']) / 1e3


def perf_op(dtype=th.float16, warmup=10, rep=50):
    """Benchmark square matmuls with Triton and Torch (and CUTLASS if available).

    For every combination of transposed / non-transposed A and B and every
    size in a fixed list, times `th.matmul` and `tt.ops.matmul` with
    `do_bench` and converts the timings to TFLOPS. When the CUTLASS_PROFILER
    environment variable is set, the program it names is also run for the
    same problem. The collected table is printed; nothing is returned.

    `dtype` is the dtype of the generated inputs; `warmup` and `rep` are the
    warmup and timed iteration counts passed to `do_bench` (and to the
    CUTLASS profiler).
    """
    import os

    import pandas as pd

    # Path of the CUTLASS profiler executable, or None when not configured.
    profiler = os.environ.get('CUTLASS_PROFILER')
    Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
    configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]

    rows = []
    for AT, BT, M, N, K in configs:
        # Inputs are scaled by 1/sqrt(K). A transposed operand is allocated
        # with swapped dimensions and then viewed through .t().
        a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K ** 0.5
        b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K ** 0.5
        if AT:
            a = a.t()
        if BT:
            b = b.t()

        # benchmarks
        torch_ms = do_bench(lambda: th.matmul(a, b), warmup=warmup, rep=rep)
        triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup=warmup, rep=rep)

        # Timings are in milliseconds: flops / ms * 1e3 is flops per second,
        # and a further 1e-12 converts to TFLOPS, hence the 1e-9 factor.
        num_flops = 2 * M * N * K
        torch_tflops = num_flops / torch_ms * 1e-9
        triton_tflops = num_flops / triton_ms * 1e-9

        cutlass_tflops = None
        if profiler is not None:
            cutlass_tflops = _cutlass_tflops(profiler, M, N, K, AT, BT, warmup, rep)

        rows.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops,
                     'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops})

    # Build the table once from the collected rows; DataFrame.append is no
    # longer available in current pandas and copied the frame on every call.
    df = pd.DataFrame(rows, columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
    pd.options.display.float_format = '{:.2f}'.format
    print(df)
|