Improvements w/ Auto-Tuning and standard benchmarks (#57)
[PYTHON] Bug-fixes in the auto-tuning module and improvement of the existing API for it
This commit is contained in:
committed by
Philippe Tillet
parent
ad005d49ac
commit
6fb4800f57
@@ -3,57 +3,58 @@ import itertools
|
||||
import triton as tt
|
||||
import torch as th
|
||||
|
||||
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
|
||||
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
|
||||
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
|
||||
(128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
|
||||
]
|
||||
for DTYPE in ['float16']
|
||||
for AT in [False, True]
|
||||
for BT in [False, True]
|
||||
]))
|
||||
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
|
||||
th.manual_seed(0)
|
||||
tt.ops._matmul.kernel = dict()
|
||||
tt.ops._matmul.TM = [TM]
|
||||
tt.ops._matmul.TN = [TN]
|
||||
tt.ops._matmul.TK = [TK]
|
||||
tt.ops._matmul.num_warps = [NWARP]
|
||||
tt.ops._matmul._kernels = dict()
|
||||
tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
|
||||
if M is None: M = TM
|
||||
if N is None: N = TN
|
||||
if K is None: K = TK
|
||||
if K is None: K = TK*TZ
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
|
||||
a = a.t() if AT else a
|
||||
@@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
|
||||
return time_ms
|
||||
|
||||
|
||||
def perf_op(dtype=th.float16, warmup=10, rep=50):
|
||||
def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50):
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
AT, BT = False, False
|
||||
has_cutlass = 'CUTLASS_PROFILER' in os.environ
|
||||
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
|
||||
Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
|
||||
df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS'])
|
||||
Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144]
|
||||
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
|
||||
for AT, BT, M, N, K in configs:
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
@@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
|
||||
cutlass_tflops = max(df_c['GFLOPs'])/1e3
|
||||
else:
|
||||
cutlass_tflops = None
|
||||
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
|
||||
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
|
||||
print(df)
|
||||
df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
|
||||
# name
|
||||
AT = {True: 'T', False: 'N'}[AT]
|
||||
BT = {True: 'T', False: 'N'}[BT]
|
||||
name = f'{AT}{BT}'
|
||||
df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green'])
|
||||
plt.savefig(f'matmul-{mode}-{name}.pdf')
|
Reference in New Issue
Block a user