[PYTHON] CUTLASS wrapper for fair benchmarks (#75)

Before this commit, the benchmarking infrastructure used heterogeneous protocols across libraries (e.g., CUTLASS was driven by a C++ binary that reports mean TFLOPS, while torch and triton are called from Python and report the 10th, 50th, and 90th percentiles). For the sake of uniformity and fair benchmarking practices, this PR adds a Python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to the CUTLASS profiler. Importantly, this also ensures that all the matmuls are run on the *same* input data, which should stabilize clocks across providers.
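To compare providers on a single scale, the millisecond quantiles returned by `do_bench`-style timing can be converted into TFLOPS. A minimal sketch (illustrative, not code from this commit; the helper name `ms_to_tflops` is hypothetical):

```python
def ms_to_tflops(ms, M, N, K):
    """Convert a matmul timing in milliseconds to TFLOPS.

    A dense (M x K) @ (K x N) matmul performs 2*M*N*K floating-point ops
    (one multiply and one add per inner-product term).
    """
    flops = 2 * M * N * K
    return flops / (ms * 1e-3) / 1e12

# e.g. a 1024x1024x1024 matmul measured at 0.5 ms:
print(ms_to_tflops(0.5, 1024, 1024, 1024))  # -> ~4.295 TFLOPS
```

Applying this conversion to the median (50th percentile) timing gives a number directly comparable to the mean TFLOPS previously reported by the CUTLASS C++ binary.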
This commit is contained in:
Philippe Tillet
2021-03-09 16:32:44 -05:00
committed by Philippe Tillet
parent d6f18742b1
commit eacbb73968
6 changed files with 257 additions and 41 deletions


@@ -1,6 +1,11 @@
import torch
import os
try:
    import triton._C.libtriton.cutlass as _cutlass
except ImportError:
    _cutlass = None

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -9,6 +14,15 @@ def sparsify_tensor(x, mask, block):
    return ret

def cutlass_matmul(a, b):
    if _cutlass is None:
        raise RuntimeError("Cannot find cutlass library")
    M, N = a.shape[0], b.shape[1]
    # allocate the output column-major (strides (1, M)), as CUTLASS expects
    c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
    _cutlass.matmul(a, b, c)
    return c
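The `(1, M)` strides passed to `torch.empty_strided` lay the output out column-major, the layout the CUTLASS kernel writes into. A minimal sketch of the offset arithmetic (illustrative only; `flat_offset` is a hypothetical helper, not part of this diff):

```python
def flat_offset(i, j, strides):
    # element (i, j) lives at i*strides[0] + j*strides[1] in the flat buffer
    return i * strides[0] + j * strides[1]

M, N = 3, 2
col_major = (1, M)  # the strides used in the empty_strided call above

# walking down a column touches consecutive memory locations:
print([flat_offset(i, 0, col_major) for i in range(M)])  # -> [0, 1, 2]
print([flat_offset(i, 1, col_major) for i in range(M)])  # -> [3, 4, 5]
```

Allocating `c` with the layout the kernel natively produces avoids a transpose or copy after the call.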
def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):