[RUNTIME] Config hook v2.0 (#373)

* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
This commit is contained in:
daadaada
2021-11-22 03:20:59 +08:00
committed by GitHub
parent 5693b582ea
commit 1296eb877b
7 changed files with 44 additions and 40 deletions

View File

@@ -29,16 +29,16 @@ matmul_data = {
(1024, 1024, 1024 ) : {'v100': 0.466},
(2048, 2048, 2048 ) : {'v100': 0.680},
(4096, 4096, 4096 ) : {'v100': 0.831},
(8192, 8192, 8192 ) : {'v100': 0.841},
(8192, 8192, 8192 ) : {'v100': 0.849},
# tall-skinny
(16 , 1024, 1024 ) : {'v100': 0.0128},
(16 , 4096, 4096 ) : {'v100': 0.0558},
(16 , 4096, 4096 ) : {'v100': 0.0883},
(16 , 8192, 8192 ) : {'v100': 0.101},
(64 , 1024, 1024 ) : {'v100': 0.049},
(64 , 4096, 4096 ) : {'v100': 0.211},
(64 , 1024, 1024 ) : {'v100': 0.073},
(64 , 4096, 4096 ) : {'v100': 0.228},
(64 , 8192, 8192 ) : {'v100': 0.360},
(1024, 64 , 1024 ) : {'v100': 0.0469},
(4096, 64 , 4096 ) : {'v100': 0.198},
(1024, 64 , 1024 ) : {'v100': 0.0692},
(4096, 64 , 4096 ) : {'v100': 0.223},
(8192, 64 , 8192 ) : {'v100': 0.323},
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},
@@ -56,7 +56,7 @@ def test_matmul(M, N, K):
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((K, N), dtype=torch.float16, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2.*M*N*K/ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
@@ -101,7 +101,7 @@ def test_elementwise(N):
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -67,7 +67,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []

View File

@@ -574,7 +574,7 @@ class Kernel:
prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = sys.modules[self.fn.module].__dict__
gscope = self.fn.fn.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.fn.parse())
@@ -698,6 +698,7 @@ class Autotuner:
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -711,11 +712,14 @@ class Autotuner:
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
@@ -726,6 +730,8 @@ class Autotuner:
config = self.cache[key]
else:
config = self.configs[0]
if config.pre_hook != None:
config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@@ -893,11 +899,14 @@ class Config:
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
function are args.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2):
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
self.pre_hook = pre_hook
def autotune(configs, key, reset_to_zero=None):
@@ -963,7 +972,7 @@ def heuristics(values):
def fun(*args, **meta):
for v, heur in values.items():
assert v not in meta
meta[v] = heur(*args, **meta)
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
return kernel(*args, **meta)
return fun

View File

@@ -12,7 +12,7 @@ import torch
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0,
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(

View File

@@ -11,8 +11,8 @@ def num_warps(n):
return 16
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])})
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
@triton.jit
def _forward(
X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
@@ -71,8 +71,8 @@ def _forward(
tl.store(px, x, mask=check)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
pidhm = tl.program_id(0)

View File

@@ -23,8 +23,8 @@ def num_warps(N):
return 16
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
@@ -48,8 +48,8 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
tl.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)

View File

@@ -2,9 +2,11 @@ import torch
import triton.language as tl
import triton
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
@triton.heuristics({
'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
@@ -18,6 +20,14 @@ import triton
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
],
key=['M', 'N', 'K'],
)
@@ -26,7 +36,6 @@ def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
LOCKS,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
# matrix multiplication
@@ -70,18 +79,7 @@ def _kernel(A, B, C, M, N, K,
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
LOCKS = LOCKS + tl.program_id(0)
COUNT = LOCKS + tl.num_programs(0)
while tl.atomic_cas(LOCKS, 0, 1) == 1:
pass
count = tl.load(COUNT)
if count == 0:
tl.store(C, acc, mask=mask)
else:
curr = tl.load(C, mask=mask, other=0.)
tl.store(C, acc + curr, mask=mask)
tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
tl.atomic_xchg(LOCKS, 0)
tl.atomic_add(C, acc, mask=mask)
class _matmul(torch.autograd.Function):
@@ -103,17 +101,13 @@ class _matmul(torch.autograd.Function):
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocate locks for split-k
if a.device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
locks, GROUP_M=8)
GROUP_M=8)
return c
@staticmethod