Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug-fixes in the auto-tuning module and improvement of the existing API for it
This commit is contained in:
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions

View File

@@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
@@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(ry , ty, rtol=rtol, atol=atol)
## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------
def do_bench(fn, warmup = 10, rep = 50):
import torch as th
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
th.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
th.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
Z, H = 1, 1
K = 512
make_layout = {
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[LAYOUT_MODE]
for N in [128, 256, 512, 1024, 2048, 4096]:
# create layout
M, N, K = N, N, N
shape = {'sdd': (M, N),
'dsd': (K, M) if TRANS_A else (M, K),
'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
# create op
op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
# inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
triton_tflops = num_flops / ms * 1e3
def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
Z, H = 1, 1
K = 512
make_layout = {
'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[LAYOUT_MODE]
for N in [128, 256, 512, 1024, 2048, 4096]:
layout = make_layout(H, N//BLOCK, N//BLOCK)
a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
a = sparsify_tensor(a, layout, BLOCK)
op = tt.ops.blocksparse.softmax(layout, BLOCK)
ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
nbytes = 2 * a.numel() * a.element_size()
triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
print(triton_gbyps)