Improvements to auto-tuning and standard benchmarks (#57)

[PYTHON] Bug fixes in the auto-tuning module and improvements to its existing API
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions


@@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
@@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(ry, ty, rtol=rtol, atol=atol)
## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------
def do_bench(fn, warmup = 10, rep = 50):
    import torch as th
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms
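For reference, a minimal sketch of how this helper can be exercised on its own, assuming do_bench from the listing above is in scope and a CUDA device is available; the dense softmax workload and the sizes below are illustrative and not part of this commit.

import torch

# Hypothetical standalone use of the do_bench helper above:
# time a dense row-wise softmax on the GPU and report milliseconds per call.
x = torch.randn((4096, 4096), device='cuda', dtype=torch.float16)
ms = do_bench(lambda: torch.softmax(x, dim=-1), warmup=10, rep=50)
print(f'{ms:.3f} ms per call')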
def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        # create layout
        M, N, K = N, N, N
        shape = {'sdd': (M, N),
                 'dsd': (K, M) if TRANS_A else (M, K),
                 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
        layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
        # create op
        op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
        # inputs
        a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
        b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
        b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
        ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
        triton_tflops = num_flops / ms * 1e3
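The num_flops expression counts 2 x BLOCK x BLOCK multiply-add flops per nonzero block of the layout, scaled by the reduction dimension, and triton_tflops then divides by the measured milliseconds to obtain TFLOP/s. A quick sanity check of that arithmetic, with concrete sizes chosen purely for illustration:

import torch

# Worked example of the FLOP count used above for OP_MODE == 'sdd':
# a lower-triangular layout over a 4096 x 4096 output with 64 x 64 blocks.
Z, BLOCK, N, K = 1, 64, 4096, 4096
layout = torch.tril(torch.ones((1, N // BLOCK, N // BLOCK), dtype=torch.int64))
print(int(layout.sum()))                                        # 2080 nonzero blocks
print(2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12)  # ~0.070 TFLOP per call
print(2 * N * N * K * 1e-12)                                    # ~0.137 TFLOP if dense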
def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        layout = make_layout(H, N//BLOCK, N//BLOCK)
        a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK)
        op = tt.ops.blocksparse.softmax(layout, BLOCK)
        ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
        nbytes = 2 * a.numel() * a.element_size()
        triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
        print(triton_gbyps)
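The softmax benchmark reports bandwidth instead: nbytes assumes each stored fp16 element is read once and written once. A rough sketch of that unit conversion, with a made-up timing rather than a measured one:

import torch

# Illustrative bandwidth arithmetic for perf_softmax above (numbers are not measured):
BLOCK, N = 64, 4096
layout = torch.tril(torch.ones((1, N // BLOCK, N // BLOCK), dtype=torch.int64))
numel = int(layout.sum()) * BLOCK * BLOCK   # elements kept after sparsification
nbytes = 2 * numel * 2                      # one read + one write, 2 bytes per fp16 element
ms = 0.05                                   # hypothetical do_bench result
print((nbytes * 1e-9) / (ms * 1e-3))        # ~682 GB/s for these numbers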


@@ -3,57 +3,58 @@ import itertools
import triton as tt
import torch as th
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
(128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
]
for DTYPE in ['float16']
for AT in [False, True]
for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
th.manual_seed(0)
tt.ops._matmul.kernel = dict()
tt.ops._matmul.TM = [TM]
tt.ops._matmul.TN = [TN]
tt.ops._matmul.TK = [TK]
tt.ops._matmul.num_warps = [NWARP]
tt.ops._matmul._kernels = dict()
tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK
if K is None: K = TK*TZ
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
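The new TZ meta-parameter in the configs above (each config is now a pair of a meta-parameter dictionary and a warp count, as in the _CONFIGS line), together with K = TK*TZ in the test body, corresponds to split-K matmul: the reduction dimension is partitioned into TZ slices whose partial products are reduced at the end. The plain-PyTorch sketch below is only a reference for that semantics, not the Triton implementation:

import torch

def splitk_matmul(a, b, TZ):
    # Reference semantics for split-K: each of the TZ slices of the K dimension
    # produces a partial product; the partials are summed at the end.
    M, K = a.shape
    N = b.shape[1]
    partial = torch.zeros((TZ, M, N), dtype=torch.float32, device=a.device)
    for z in range(TZ):
        ks, ke = z * K // TZ, (z + 1) * K // TZ
        partial[z] = a[:, ks:ke].float() @ b[ks:ke, :].float()
    return partial.sum(dim=0)

a = torch.randn(64, 64, device='cuda', dtype=torch.float16)
b = torch.randn(64, 64, device='cuda', dtype=torch.float16)
assert torch.allclose(splitk_matmul(a, b, TZ=4), a.float() @ b.float(), atol=1e-3)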
@@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
return time_ms
def perf_op(dtype=th.float16, warmup=10, rep=50):
def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50):
import pandas as pd
import matplotlib.pyplot as plt
import os
AT, BT = False, False
has_cutlass = 'CUTLASS_PROFILER' in os.environ
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS'])
Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144]
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
for AT, BT, M, N, K in configs:
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
@@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
cutlass_tflops = max(df_c['GFLOPs'])/1e3
else:
cutlass_tflops = None
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
print(df)
df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
# name
AT = {True: 'T', False: 'N'}[AT]
BT = {True: 'T', False: 'N'}[BT]
name = f'{AT}{BT}'
df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green'])
plt.savefig(f'matmul-{mode}-{name}.pdf')
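For completeness, a minimal sketch of the pandas/matplotlib reporting pattern used above; the data values, the 'NN' title, and the 'square' mode in the file name are placeholders, not measured results:

import pandas as pd
import matplotlib
matplotlib.use('Agg')  # write the figure to disk without a display
import matplotlib.pyplot as plt

# Hypothetical data frame mirroring the N / Triton / Torch columns built above.
df = pd.DataFrame({'N': [128, 256, 512], 'Triton': [10., 30., 55.], 'Torch': [12., 35., 60.]})
df.plot.line(x='N', y=['Triton', 'Torch'], title='NN', color=['purple', 'blue'])
plt.savefig('matmul-square-NN.pdf')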