Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug fixes in the auto-tuning module and improvements to its existing API
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions

View File

@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
void register_fn(int op_id,
int dev_id,
const std::string& src,
const rt::options_space_t& opt,
const rt::options_t& opt,
const rt::function::autotune_vals_t& autotune_vals,
const std::vector<std::string>& autotune_key) {
if(tt_devices.find(dev_id) == tt_devices.end()) {
driver::device* device;
@@ -62,7 +63,7 @@ void register_fn(int op_id,
tt_streams[dev_id].reset(stream);
}
if(id_fn_map.find(op_id) == id_fn_map.end()){
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for(const auto& k: id_fn_map[op_id]->get_kernels()){
const rt::options_t* opt = &k.first;
@@ -197,13 +198,9 @@ PYBIND11_MODULE(libtriton, m) {
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def_readwrite("defines" , &rt::options_t::defines);
pybind11::class_<rt::options_space_t>(m, "options_space")
.def(pybind11::init<>())
.def_readwrite("num_warps", &rt::options_space_t::num_warps)
.def_readwrite("defines" , &rt::options_space_t::defines);
.def_readwrite("defines" , &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
// hooks into triton constructs since frameworks may not use pybind11
m.def("extract_kernels", &extract_kernels);
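For context, a minimal sketch (not part of this commit) of how the reworked binding is driven from the Python side: a single options object plus a list of candidate configurations replaces the old options_space. The device_id and src values and the concrete configurations below are placeholders; the role of autotune_key is inferred from how it is used later in this commit.

import libtriton

# device_id: integer device index; src: Triton-C source string (placeholders)
# baseline compilation options (a single configuration, not a search space)
opt = libtriton.options()
opt.num_warps = 4
opt.defines = {'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}

# each auto-tuning candidate is a (defines, num_warps) pair
autotune_vals = [({'TM': '64',  'TN': '64',  'TK': '64', 'TZ': '1'}, 2),
                 ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4)]
# argument names whose runtime values presumably key the auto-tuning cache
autotune_key = ['M', 'N', 'K']

op_id = libtriton.make_op_id()
libtriton.register_fn(op_id, device_id, src, opt, autotune_vals, autotune_key)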

View File

@@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
## -----------------------------------------------------------------------------
## Unit Tests
## -----------------------------------------------------------------------------
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
@@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16):
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(ry , ty, rtol=rtol, atol=atol)
## -----------------------------------------------------------------------------
## Performance Tests
## -----------------------------------------------------------------------------
def do_bench(fn, warmup = 10, rep = 50):
    import torch as th
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms
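As a quick usage illustration (not part of the diff), do_bench times any zero-argument callable, e.g. a dense torch matmul:

a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=50)  # average milliseconds per call
print(f'{ms:.3f} ms')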
def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        # create layout
        M, N, K = N, N, N
        shape = {'sdd': (M, N),
                 'dsd': (K, M) if TRANS_A else (M, K),
                 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE]
        layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK)
        # create op
        op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B)
        # inputs
        a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
        b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a
        b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b
        ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12,
                     'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE]
        triton_tflops = num_flops / ms * 1e3
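To make the TFLOPS computation above concrete, a worked example with illustrative numbers (not measurements): with BLOCK=64 and M=N=K=1024, the block grid is 16x16, so a 'tril' layout keeps 16*17/2 = 136 non-zero blocks.

# illustrative only: expected TFLOPS for one 'sdd' data point of perf_matmul
Z, K, BLOCK = 1, 1024, 64
nnz_blocks = 16 * 17 // 2                                    # tril layout on a 16x16 block grid -> 136
num_flops = 2 * Z * K * nnz_blocks * BLOCK * BLOCK * 1e-12   # ~1.14e-3 TFLOP of work
ms = 0.1                                                     # hypothetical measured time
print(num_flops / ms * 1e3)                                  # ~11.4 TFLOPS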
def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50):
    Z, H = 1, 1
    K = 512
    make_layout = {
        'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[LAYOUT_MODE]
    for N in [128, 256, 512, 1024, 2048, 4096]:
        layout = make_layout(H, N//BLOCK, N//BLOCK)
        a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda')
        a = sparsify_tensor(a, layout, BLOCK)
        op = tt.ops.blocksparse.softmax(layout, BLOCK)
        ms = do_bench(lambda: op(a), warmup=warmup, rep=rep)
        nbytes = 2 * a.numel() * a.element_size()
        triton_gbyps = (nbytes*1e-9) / (ms*1e-3)
        print(triton_gbyps)
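Likewise, the softmax loop reports effective bandwidth (one read plus one write of the sparsified tensor); a hypothetical data point, not a measurement:

# illustrative only: effective bandwidth for one perf_softmax data point
nbytes = 2 * 1_000_000 * 2              # read + write of 1M float16 elements
ms = 0.05                               # hypothetical measured time
print((nbytes * 1e-9) / (ms * 1e-3))    # 80.0 GB/s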

View File

@@ -3,57 +3,58 @@ import itertools
import triton as tt
import torch as th
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
(128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE)
]
for DTYPE in ['float16']
for AT in [False, True]
for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
th.manual_seed(0)
tt.ops._matmul.kernel = dict()
tt.ops._matmul.TM = [TM]
tt.ops._matmul.TN = [TN]
tt.ops._matmul.TK = [TK]
tt.ops._matmul.num_warps = [NWARP]
tt.ops._matmul._kernels = dict()
tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK
if K is None: K = TK*TZ
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
@@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
return time_ms
def perf_op(dtype=th.float16, warmup=10, rep=50):
def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50):
import pandas as pd
import matplotlib.pyplot as plt
import os
AT, BT = False, False
has_cutlass = 'CUTLASS_PROFILER' in os.environ
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS'])
Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144]
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
for AT, BT, M, N, K in configs:
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
@@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
cutlass_tflops = max(df_c['GFLOPs'])/1e3
else:
cutlass_tflops = None
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
print(df)
df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
# name
AT = {True: 'T', False: 'N'}[AT]
BT = {True: 'T', False: 'N'}[BT]
name = f'{AT}{BT}'
df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green'])
plt.savefig(f'matmul-{mode}-{name}.pdf')
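For reference, a usage sketch of the reworked benchmark now that the transpose modes are arguments rather than hard-coded (the CUTLASS column is only filled in when CUTLASS_PROFILER is set in the environment):

perf_op(AT=False, BT=True, MODE='square', dtype=th.float16, warmup=10, rep=50)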

View File

@@ -26,10 +26,8 @@ def th_to_triton(obj):
torch.float64: 'double'
}
if isinstance(obj, torch.dtype):
return [tys[obj]]
if isinstance(obj, list):
return [th_to_triton(x)[0] for x in obj]
return [str(obj)]
return tys[obj]
return str(obj)
def cdiv(a, b):
return libtriton.cdiv(a, b)
@@ -45,17 +43,15 @@ def read(path, kernel_names=[]):
source = libtriton.extract_kernels(source, kernel_names)
return source
class kernel:
def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
self.opt = libtriton.options_space()
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
self.opt = libtriton.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# device
assert device.type in ['cuda', 'cpu']
@@ -65,7 +61,7 @@ class kernel:
self.device = -1
# C++ function wrapper
self.op_id = libtriton.make_op_id()
libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
# debug mode
self.is_debug = 'TRITON_DEBUG' in os.environ
# signature
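Putting the constructor change together, a hedged sketch of building a kernel through the revised interface: defines now map to single values, the warp count is a scalar, and auto-tuning candidates are passed separately. The source path and the concrete meta-parameter values below are placeholders.

import torch
import triton

src = triton.read('my_kernel.c')        # placeholder path to a Triton-C source file
kernel = triton.kernel(src,
                       device=torch.device('cuda:0'),
                       defines={'TM': 128, 'TN': 128, 'TYPE': torch.float16},
                       num_warps=4,
                       autotune_vals=[({'TM': '64',  'TN': '64'},  2),
                                      ({'TM': '128', 'TN': '128'}, 4)],
                       autotune_key=['M', 'N'])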

View File

@@ -81,7 +81,7 @@ class _matmul(torch.autograd.Function):
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 64 // block
start_width = 128 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
@@ -126,22 +126,18 @@ class _matmul(torch.autograd.Function):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
F32TK = [8, 16]
#F16TK = [16]
#F16TK += [32] if is_32_multiple else []
#F16TK += [64] if is_64_multiple else []
F16TK = [64]
TK = {torch.float32: F32TK,
torch.float16: F16TK}[dtype]
defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
'TK': TK, 'TYPE': dtype,
defines = {'TM': block*pack, 'TN': block*pack,
'TMN': block*block*pack*pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
kernel = _matmul.sdd_cache[key]
# create output
@@ -270,9 +266,9 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': TM, 'TN': block, 'TK': TK,
defines = {'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
@@ -283,7 +279,7 @@ class _matmul(torch.autograd.Function):
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
@@ -315,9 +311,9 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
TN = [64, 128] if dtype == torch.float32 else [64, 128]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': block, 'TN': TN, 'TK': TK,
defines = {'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
@@ -328,7 +324,7 @@ class _matmul(torch.autograd.Function):
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0
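With the meta-parameters now fixed inside the op, callers only deal with the op objects; a short usage sketch mirroring the new performance tests (shapes and values are illustrative):

import torch
import triton as tt

BLOCK, H, M, N, K = 64, 1, 1024, 1024, 1024
layout = torch.tril(torch.ones((H, M // BLOCK, N // BLOCK), dtype=torch.int64))
op = tt.ops.blocksparse.matmul(layout, BLOCK, 'sdd', trans_a=False, trans_b=False)
a = torch.randn((1, H, M, K), dtype=torch.float16, device='cuda')
b = torch.randn((1, H, K, N), dtype=torch.float16, device='cuda')
c = op(a, b)   # sparse output: one BLOCKxBLOCK tile per non-zero entry of `layout`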

View File

@@ -48,7 +48,7 @@ class _softmax(torch.autograd.Function):
# just-in-time compile kernel
key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
if key not in cache:
defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
'INFINITY': {torch.float32: 'F32_INFINITY',
torch.float16: 'F16_INFINITY'}[dtype]}
if apply_scale:
@@ -63,7 +63,7 @@ class _softmax(torch.autograd.Function):
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, device=device, defines=defines, num_warps=[num_warps])
kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
cache[key] = kernel
return cache[key]
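The block-sparse softmax follows the same pattern; a minimal sketch, assuming the input is already in the sparsified (Z, nnz, BLOCK, BLOCK) format produced by the test helper sparsify_tensor:

import torch
import triton as tt

BLOCK, H, N = 64, 1, 1024
layout = torch.tril(torch.ones((H, N // BLOCK, N // BLOCK), dtype=torch.int64))
op = tt.ops.blocksparse.softmax(layout, BLOCK)
a = torch.randn((1, int(layout.sum()), BLOCK, BLOCK), dtype=torch.float16, device='cuda')
b = op(a)      # softmax over each row, restricted to the non-zero blocks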

View File

@@ -29,10 +29,10 @@ class _conv(torch.autograd.Function):
TK = 16
defines = {
'TYPE' : dtype,
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'TM' : 64,
'TN' : 64,
'TK' : TK,
'TZ' : 1,
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
@@ -40,7 +40,7 @@ class _conv(torch.autograd.Function):
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
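The delta table above is the usual implicit-GEMM pointer-increment trick: for every flattened reduction index it precomputes how far the input pointer must move when the kernel advances TK steps along the combined (CI, R, S) axis. A hedged sketch of the index decomposition this assumes (the real _conv.unpack may differ in detail):

# assumed flattening: idx = (ci * R + r) * S + s
def unpack(idx, CI, R, S):
    s = idx % S
    r = (idx // S) % R
    ci = idx // (R * S)
    return ci, r, s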

View File

@@ -83,8 +83,8 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)

View File

@@ -5,11 +5,21 @@ import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
TM = [128]
TN = [128]
TK = [32]
TZ = 1
num_warps = [4]
_DEFAULT_CONFIGS = [
({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4),
({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4),
({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
@@ -41,7 +51,7 @@ class _matmul(torch.autograd.Function):
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 32 == 0
is_tk_div_k = K % 64 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
if key not in _matmul._kernels:
defines = {
@@ -53,13 +63,10 @@ class _matmul(torch.autograd.Function):
'LDA_POW2_DIV': lda_pow2_div,
'LDB_POW2_DIV': ldb_pow2_div,
'LDC_POW2_DIV': ldc_pow2_div,
'TM' : _matmul.TM,
'TN' : _matmul.TN,
'TK' : _matmul.TK,
'TZ' : _matmul.TZ,
'IS_TK_DIV_K' : int(is_tk_div_k)
}
_matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
_matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
kernel = _matmul._kernels[key]
# # locks for split-k
if device not in _matmul._locks:
@@ -68,7 +75,7 @@ class _matmul(torch.autograd.Function):
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
kernel(*args, grid=grid)
return c
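Two points worth spelling out about the auto-tuned matmul (an illustrative note, not part of the diff): each _DEFAULT_CONFIGS entry is a (meta-parameter defines, num_warps) candidate handed to the runtime through autotune_vals and keyed on the values of M, N and K; and because the grid now includes opt.TZ, a split-k candidate launches TZ program instances per output tile, whose partial results the spin-locks in matmul.c reconcile.

# illustrative grid-size computation for one candidate configuration
import triton

M, N, K = 4096, 4096, 4096
TM, TN, TZ = 128, 32, 4     # e.g. the ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4) entry
grid = [triton.cdiv(M, TM) * triton.cdiv(N, TN), 1, TZ]
print(grid)                 # [4096, 1, 4] -> 32 * 128 tiles, each split 4 ways along K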