diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index 215003447..eff21fdfd 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -29,16 +29,16 @@ matmul_data = {
     (1024, 1024, 1024 ) : {'v100': 0.466},
     (2048, 2048, 2048 ) : {'v100': 0.680},
     (4096, 4096, 4096 ) : {'v100': 0.831},
-    (8192, 8192, 8192 ) : {'v100': 0.841},
+    (8192, 8192, 8192 ) : {'v100': 0.849},
     # tall-skinny
     (16  , 1024, 1024 ) : {'v100': 0.0128},
-    (16  , 4096, 4096 ) : {'v100': 0.0558},
+    (16  , 4096, 4096 ) : {'v100': 0.0883},
     (16  , 8192, 8192 ) : {'v100': 0.101},
-    (64  , 1024, 1024 ) : {'v100': 0.049},
-    (64  , 4096, 4096 ) : {'v100': 0.211},
+    (64  , 1024, 1024 ) : {'v100': 0.073},
+    (64  , 4096, 4096 ) : {'v100': 0.228},
     (64  , 8192, 8192 ) : {'v100': 0.360},
-    (1024, 64  , 1024 ) : {'v100': 0.0469},
-    (4096, 64  , 4096 ) : {'v100': 0.198},
+    (1024, 64  , 1024 ) : {'v100': 0.0692},
+    (4096, 64  , 4096 ) : {'v100': 0.223},
     (8192, 64  , 8192 ) : {'v100': 0.323},
     # # deep reductions
     # (64  , 64  , 16384) : {'v100': 0.},
@@ -56,7 +56,7 @@ def test_matmul(M, N, K):
     a = torch.randn((M, K), dtype=torch.float16, device='cuda')
     b = torch.randn((K, N), dtype=torch.float16, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
-    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
     cur_gpu_perf = 2.*M*N*K/ms * 1e-9
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
@@ -101,7 +101,7 @@ def test_elementwise(N):
     y = torch.randn_like(z)
     grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
     fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
-    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
     cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py
index 0751d044d..75241c291 100644
--- a/python/test/unit/operators/test_matmul.py
+++ b/python/test/unit/operators/test_matmul.py
@@ -67,7 +67,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
     kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
-    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
+    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
+    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
     kernel = triton.ops._matmul.kernel
     decorators = kernel.kernel_decorators
     kernel.kernel_decorators = []
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index efcc2701f..84d77795c 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -574,7 +574,7 @@ class Kernel:
         prototype = _triton.ir.type.make_function(ret_type, arg_types)
         # generate Triton-IR
         # export symbols visible from self.fn into code-generator object
-        gscope = sys.modules[self.fn.module].__dict__
+        gscope = self.fn.fn.__globals__
         generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
         try:
             generator.visit(self.fn.parse())
@@ -698,6 +698,7 @@ class Autotuner:
                 for i in self.reset_idx:
                     args[i].zero_()
             self.hook = _hook
+        self.arg_names = arg_names
 
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -711,11 +712,14 @@ class Autotuner:
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.kwargs)
         def kernel_call():
+            if config.pre_hook:
+                config.pre_hook(self.nargs)
             self.hook(args)
             self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)
 
     def __call__(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
@@ -726,6 +730,8 @@ class Autotuner:
             config = self.cache[key]
         else:
             config = self.configs[0]
+        if config.pre_hook != None:
+            config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
 
 
@@ -893,11 +899,14 @@ class Config:
     :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
     :type num_stages: int
+    :ivar pre_hook: a function that will be called before the kernel is called. Its single argument is a dict
+                    mapping the kernel's argument names to the values the kernel is about to be launched with.
     """
-    def __init__(self, kwargs, num_warps=4, num_stages=2):
+    def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.pre_hook = pre_hook
 
 
 def autotune(configs, key, reset_to_zero=None):
@@ -963,7 +972,7 @@ def heuristics(values):
             def fun(*args, **meta):
                 for v, heur in values.items():
                     assert v not in meta
-                    meta[v] = heur(*args, **meta)
+                    meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
                 return kernel(*args, **meta)
             return fun
 
diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py
index 49497777a..9c3317fe0 100644
--- a/python/triton/ops/blocksparse/matmul.py
+++ b/python/triton/ops/blocksparse/matmul.py
@@ -12,7 +12,7 @@ import torch
 # ********************************************************
 
 @triton.heuristics({
-    'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0,
+    'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
 })
 @triton.jit
 def _sdd_kernel(
diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py
index 5b9d752ec..dcf77afc8 100644
--- a/python/triton/ops/blocksparse/softmax.py
+++ b/python/triton/ops/blocksparse/softmax.py
@@ -11,8 +11,8 @@ def num_warps(n):
     return 16
 
 
-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
+@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
 @triton.jit
 def _forward(
     X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
@@ -71,8 +71,8 @@ def _forward(
     tl.store(px, x, mask=check)
 
 
-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
+@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
 @triton.jit
 def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
     pidhm = tl.program_id(0)
diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py
index 8711d5b19..529b6c675 100644
--- a/python/triton/ops/cross_entropy.py
+++ b/python/triton/ops/cross_entropy.py
@@ -23,8 +23,8 @@ def num_warps(N):
     return 16
 
 
-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
-@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
 @triton.jit
 def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
@@ -48,8 +48,8 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
     tl.store(LOSS + row, probs)
 
 
-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
-@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
 @triton.jit
 def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index 802908657..ae404b8d6 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -2,9 +2,11 @@ import torch
 import triton.language as tl
 import triton
 
+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()
 
 @triton.heuristics({
-    'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
+    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
@@ -18,6 +20,14 @@ import triton
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
         triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
     ],
     key=['M', 'N', 'K'],
 )
@@ -26,7 +36,6 @@ def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
             stride_bk, stride_bn,
             stride_cm, stride_cn,
-            LOCKS,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
     # matrix multiplication
@@ -70,18 +79,7 @@ def _kernel(A, B, C, M, N, K,
     if SPLIT_K == 1:
         tl.store(C, acc, mask=mask)
     else:
-        LOCKS = LOCKS + tl.program_id(0)
-        COUNT = LOCKS + tl.num_programs(0)
-        while tl.atomic_cas(LOCKS, 0, 1) == 1:
-            pass
-        count = tl.load(COUNT)
-        if count == 0:
-            tl.store(C, acc, mask=mask)
-        else:
-            curr = tl.load(C, mask=mask, other=0.)
-            tl.store(C, acc + curr, mask=mask)
-        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
-        tl.atomic_xchg(LOCKS, 0)
+        tl.atomic_add(C, acc, mask=mask)
 
 
 class _matmul(torch.autograd.Function):
@@ -103,17 +101,13 @@ class _matmul(torch.autograd.Function):
         _, N = b.shape
         # allocates output
         c = torch.empty((M, N), device=device, dtype=a.dtype)
-        # allocate locks for split-k
-        if a.device not in _matmul._locks:
-            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
-        locks = _matmul._locks[device]
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
        _kernel[grid](a, b, c, M, N, K,
                      a.stride(0), a.stride(1),
                      b.stride(0), b.stride(1),
                      c.stride(0), c.stride(1),
-                      locks, GROUP_M=8)
+                      GROUP_M=8)
         return c
 
     @staticmethod
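For context, a minimal host-side sketch (not taken verbatim from the diff) of the two interface changes introduced above, mirroring test_matmul.py and ops/matmul.py; the argument name 'C' and the tile sizes are illustrative only:

import triton

# A Config may now carry a pre_hook. The autotuner calls it right before launching the
# kernel, passing a dict that maps the kernel's argument names to their runtime values.
# Split-K configs use it to zero the output, since partial results are now combined with
# tl.atomic_add instead of the old lock-based reduction.
split_k_config = triton.Config(
    {'BLOCK_M': 16, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 4},
    num_warps=2,
    pre_hook=lambda nargs: nargs['C'].zero_(),
)

# @triton.heuristics callables receive the same named-argument dict instead of
# (*args, **meta), so values are looked up by name rather than by positional index.
even_k_heuristic = {
    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
}

The per-config pre_hook is what makes the atomic_add-based split-K reduction safe: the output buffer must start at zero before the SPLIT_K programs accumulate into it.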