[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols between library (e.g., CUTLASS uses a C++ binary that reports mean TFLOPS; torch and triton use python call and report 10th, 50th and 90th quantiles). For the sake of uniformity and fair benchmark practices, this PR adds a python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to CUTLASS profiler. Importantly, this also ensures that all the matmuls are done on the *same* input data which should stabilize clock across providers.
This commit is contained in:
committed by
Philippe Tillet
parent
d6f18742b1
commit
eacbb73968
@@ -1,6 +1,11 @@
|
||||
import torch
|
||||
import os
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
except ImportError:
|
||||
_cutlass = None
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
@@ -9,6 +14,15 @@ def sparsify_tensor(x, mask, block):
|
||||
return ret
|
||||
|
||||
|
||||
def cutlass_matmul(a, b):
|
||||
if _cutlass is None:
|
||||
raise RuntimeError("Cannot find cutlass library")
|
||||
M, N = a.shape[0], b.shape[1]
|
||||
c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
|
||||
_cutlass.matmul(a, b, c)
|
||||
return c
|
||||
|
||||
|
||||
def mask_tensor(x, mask, block, value=0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
|
Reference in New Issue
Block a user