From 2acaa4d0dd61b4936a62327a144af536715bf96a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 30 Oct 2021 00:32:58 -0700 Subject: [PATCH] [LANG] Added support for constexpr (#361) --- lib/codegen/analysis/layout.cc | 3 - lib/codegen/transform/coalesce.cc | 3 +- lib/driver/llvm.cc | 2 +- python/test/regression/test_performance.py | 6 +- python/test/unit/language/test_core.py | 73 ++-- .../test/unit/operators/test_blocksparse.py | 6 +- python/test/unit/operators/test_matmul.py | 4 +- python/triton/code_gen.py | 107 ++++-- python/triton/language/core.py | 93 ++++- python/triton/ops/blocksparse/matmul.py | 347 ++++++------------ python/triton/ops/blocksparse/softmax.py | 26 +- python/triton/ops/cross_entropy.py | 6 +- python/triton/ops/matmul.py | 19 +- python/tutorials/01-vector-add.py | 9 +- python/tutorials/02-fused-softmax.py | 4 +- python/tutorials/03-matrix-multiplication.py | 12 +- 16 files changed, 355 insertions(+), 365 deletions(-) diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 6ea0dd219..64163c91c 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -198,8 +198,6 @@ scanline_layout::scanline_layout(size_t num_warps, bool is_dot = std::any_of(values.begin(), values.end(), [&](ir::value* v) { return dynamic_cast(v); }); - - std::vector ptrs; for(ir::value *v: values) for(ir::user *usr: v->get_users()) @@ -215,7 +213,6 @@ scanline_layout::scanline_layout(size_t num_warps, contiguous = std::max(contiguous, std::min(align->get(ptr, i), 128 / nbits)); } - nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i])); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); size /= shape_[i]; diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 448517408..ae8ce034d 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -77,7 +77,6 @@ void coalesce::run(ir::module &mod) { builder.insert(new_x); x->replace_all_uses_with(new_x); new_x->replace_uses_of_with(new_x, x); -// new_x->replace_uses_of_with(new_x, new_x); } } for(ir::function *fn: mod.get_function_list()) @@ -101,6 +100,8 @@ void coalesce::run(ir::module &mod) { ir::instruction* curr = queue.back(); seen.insert(curr); queue.pop_back(); + if(auto dot_inst = dynamic_cast(curr)) + break; if(auto io_inst = dynamic_cast(curr)){ in_contig = align_->contiguous(io_inst->get_pointer_operand()); break; diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 3c11fbf35..f3c76ce77 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { ofs.close(); std::string cmd; int err; - cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o"; + cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; err = system(cmd.c_str()); CUmodule ret; std::ifstream _cubin(_fbin, std::ios::binary ); diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index e205828d6..215003447 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -67,8 +67,8 @@ def test_matmul(M, N, K): import triton.language as tl @triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta): - BLOCK_SIZE = meta['BLOCK_SIZE'] +def _add(x_ptr, y_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start 
+ tl.arange(0, BLOCK_SIZE) @@ -99,7 +99,7 @@ def test_elementwise(N): z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) y = torch.randn_like(z) - grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), ) + grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250) cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6 diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 25054f0dc..9354ec233 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -40,7 +40,7 @@ def patch_kernel(template, to_replace): def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 @triton.jit - def kernel(X, **meta): + def kernel(X, SIZE: tl.constexpr): pass x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) kernel[(1, )](x, SIZE=SIZE, num_warps=4) @@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit - def kernel(Z, X, **meta): - off = tl.arange(0, meta['SIZE']) + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) x = tl.load(X + off) z = GENERATE_TEST_HERE tl.store(Z + off, z) @@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c SIZE = 128 # define the kernel / launch-grid @triton.jit - def kernel(Z, X, Y, **meta): - off = tl.arange(0, meta['SIZE']) + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) x = tl.load(X + off) y = tl.load(Y + off) z = GENERATE_TEST_HERE @@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'): # Triton kernel @triton.jit - def kernel(Z, X, **meta): - SIZE = meta['SIZE'] + def kernel(Z, X, SIZE: tl.constexpr): m = tl.arange(0, SIZE) n = tl.arange(0, SIZE) x = tl.load(X_PTR_EXPR) @@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): # triton kernel @triton.jit - def kernel(X, Z, **meta): + def kernel(X, Z): pid = tl.program_id(0) x = tl.load(X + pid) old = GENERATE_TEST_HERE @@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # triton kernel @triton.jit - def kernel(X, Z, **meta): + def kernel(X, Z, BITCAST: tl.constexpr): x = tl.load(X) - z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST']) + z = x.to(Z.dtype.element_ty, bitcast = BITCAST) tl.store(Z, z) # triton result @@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'): # triton kernel @triton.jit - def kernel(X, Z, **meta): - x = tl.load(X + tl.arange(0, meta['BLOCK'])) + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) tl.store(Z, tl.sum(x, axis=0)) x = triton.testing.random((shape,), dtype=dtype, device=device) @@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'): dtype = cvt[dtype] # triton kernel @triton.jit - def kernel(X, Z, **meta): - range_m = tl.arange(0, meta['BLOCK_M']) - range_n = tl.arange(0, meta['BLOCK_N']) - x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :]) - z = tl.sum(x, axis=meta['AXIS']) + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :]) + z = tl.sum(x, axis=AXIS) tl.store(Z + range_m, z) # input x = triton.testing.random(shape, dtype=dtype, device=device) @@ -429,9 +428,8 @@ def test_permute(dtype, shape, 
perm, device='cuda'): # triton kernel @triton.jit def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, **meta): - BLOCK_M = meta['BLOCK_M'] - BLOCK_N = meta['BLOCK_N'] + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn @@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'): @triton.jit def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, - Z, stride_zm, stride_zn, **meta): - BLOCK_M = meta['BLOCK_M'] - BLOCK_K = meta['BLOCK_K'] - BLOCK_N = meta['BLOCK_N'] + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) off_k = tl.arange(0, BLOCK_K) @@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'): Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn z = tl.dot(tl.load(Xs), tl.load(Ys)) - if meta['ADD_MATRIX']: + if ADD_MATRIX: z += tl.load(Zs) - if meta['ADD_ROWS']: + if ADD_ROWS: ZRs = Z + off_m * stride_zm z += tl.load(ZRs)[:, None] - if meta['ADD_COLS']: + if ADD_COLS: ZCs = Z + off_n * stride_zn z += tl.load(ZCs)[None, :] tl.store(Zs, z) @@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'): def test_dot_without_load(): @triton.jit - def kernel(out, **meta): + def kernel(out): pid = tl.program_id(axis=0) a = tl.zeros((32, 32), tl.float32) b = tl.zeros((32, 32), tl.float32) @@ -538,9 +535,10 @@ def test_arange(start, device='cuda'): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @triton.jit - def _kernel(z, **meta): - off = tl.arange(0, meta['BLOCK']) - val = tl.arange(meta['START'], meta['END']) + def _kernel(z, BLOCK: tl.constexpr, + START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) tl.store(z + off, val) _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK) z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device) @@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'): @triton.jit def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, - in_numel, in2_numel, out_numel, **meta): - M = meta['M'] - N = meta['N'] - K = meta['K'] + in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): M_offsets = tl.arange(0, M) N_offsets = tl.arange(0, N) @@ -605,14 +601,13 @@ def test_load_cache_modifier(cache): dst = torch.empty(128, device='cuda') @triton.jit - def _kernel(dst, src, **meta): + def _kernel(dst, src, CACHE: tl.constexpr): offsets = tl.arange(0, 128) - x = tl.load(src+offsets, cache_modifier=meta['CACHE']) + x = tl.load(src+offsets, cache_modifier=CACHE) tl.store(dst+offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] - if cache == '': assert 'ld.global.ca' not in ptx assert 'ld.global.cg' not in ptx @@ -644,7 +639,7 @@ def test_load_cache_modifier(cache): #---------------- def test_noop(device='cuda'): @triton.jit - def kernel(**meta): + def kernel(x): pass x = triton.testing.random((1,), dtype=torch.int32, device=device) kernel[(1, )](x) \ No newline at end of file diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 3b9a1c17f..b9cdc23c7 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ 
b/python/test/unit/operators/test_blocksparse.py @@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= }[MODE] layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) # triton result - op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest) @@ -151,8 +151,8 @@ def triton_attention( value: torch.Tensor, scale: float, ): - sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True) - sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False) + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) sparse_softmax = triton.ops.blocksparse.softmax( layout, block, diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index dbf1974ce..0751d044d 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -66,8 +66,8 @@ import torch def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually - META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} - configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)] + kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)] kernel = triton.ops._matmul.kernel decorators = kernel.kernel_decorators kernel.kernel_decorators = [] diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8622333bf..f14f3b135 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor): arg_values = [] for i, arg_name in enumerate(arg_names): if i in self.constants: - cst = triton.language.core._to_ir(self.constants[i], self.builder) + cst = self.constants[i] + if not isinstance(cst, triton.language.constexpr): + cst = triton.language.constexpr(self.constants[i]) arg_values.append(cst) else: if i in self.attributes: @@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor): fn.add_attr(i + 1, attr) fn.args[i].name = arg_name arg_values.append(fn.args[i]) + for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) if inline: @@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor): ast.NodeVisitor.generic_visit(self, node) return node.arg + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == triton.language.constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' 
+ f' constexpr cannot be reassigned.') + self.lscope[target] = triton.language.constexpr(value) + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): _names = [] for target in node.targets: @@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + if isinstance(value, triton.language.constexpr): + value = value.value if not isinstance(value, triton.language.block): value = triton.language.core._to_ir(value, self.builder) self.set_value(name, value) @@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor): def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) + if isinstance(lhs, triton.language.core.constexpr): + lhs = lhs.value + if isinstance(rhs, triton.language.core.constexpr): + rhs = rhs.value fn = { ast.Add: '__add__', ast.Sub: '__sub__', @@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - kws = dict() - if self.is_triton_object(lhs): - kws['_builder'] = self.builder - ret = getattr(lhs, fn)(rhs, **kws) - if ret is NotImplemented: - if self.is_triton_object(rhs): - kws['_builder'] = self.builder + return getattr(lhs, fn)(rhs, _builder=self.builder) + elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] - ret = getattr(rhs, fn)(lhs, **kws) - return ret + return getattr(rhs, fn)(lhs, _builder=self.builder) + else: + return getattr(lhs, fn)(rhs) def visit_If(self, node): cond = self.visit(node.test) @@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) + if isinstance(lhs, triton.language.core.constexpr): + lhs = lhs.value + if isinstance(rhs, triton.language.core.constexpr): + rhs = rhs.value fn = { ast.Eq: '__eq__', ast.NotEq: '__ne__', @@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor): def visit_UnaryOp(self, node): op = self.visit(node.operand) + if isinstance(op, triton.language.core.constexpr): + op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor): return fn(*args, **kws) def visit_Num(self, node): - return node.n + return triton.language.constexpr(node.n) def visit_Attribute(self, node): lhs = self.visit(node.value) @@ -477,6 +505,8 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] + if isinstance(obj, triton.language.core.constexpr): + obj = obj.value if isinstance(obj, int): if abs(obj) <= 0xffffffff: return 'I' @@ -485,6 +515,8 @@ class Kernel: return 'f' if isinstance(obj, bool): return 'B' + if isinstance(obj, str): + return 'str' assert False @@ -537,7 +569,8 @@ class Kernel: def __init__(self, fn): self.fn = fn - def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta): + def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): + wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel @@ -547,7 +580,7 @@ class Kernel: # generate Triton-IR # export symbols visible from self.fn into code-generator object gscope = sys.modules[self.fn.module].__dict__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta) + generator 
= CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) try: generator.visit(self.fn.parse()) except Exception as e: @@ -566,7 +599,19 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta): + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + # handle arguments passed by name + kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} + wargs = list(wargs) + for i, pos in enumerate(sorted(kwargs)): + wargs.insert(pos + i, kwargs[pos]) + if len(wargs) != len(self.fn.arg_names): + raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") + # handle annotations + for name, type in self.fn.__annotations__.items(): + pos = self.fn.arg_names.index(name) + assert type == triton.language.core.constexpr + wargs[pos] = type(wargs[pos]) # device inference tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] if len(tensor_idxs) == 0: @@ -601,18 +646,19 @@ class Kernel: args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ if isinstance(a, int) and i not in self.fn.do_not_specialize} + # transforms ints whose value is one into constants for just-in-time compilation constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + # compute hash for caching this kernel types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) attr_key = tuple(attributes.items()) - meta_key = tuple(sorted(meta.items())) const_key = tuple(constants.items()) compute_capability = torch.cuda.get_device_capability(device) - key = ( self.fn.cache_key, version_key(), compute_capability, - types_key, attr_key, num_warps, num_stages, meta_key, const_key + types_key, attr_key, num_warps, num_stages, const_key ) key = repr(key) @@ -644,7 +690,7 @@ class Kernel: binary = self._compile( *wargs, device=device_idx, attributes=attributes, num_warps=num_warps, num_stages=num_stages, - constants=constants, **meta + constants=constants, ) if bin_cache_path: assert bin_lock_path is not None @@ -657,12 +703,15 @@ class Kernel: drv_cache[key] = LoadedBinary(device_idx, binary) # pack arguments - fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)]) - params = struct.pack(fmt, *args) + fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)]) + params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)]) # enqueue cached function into stream callable = drv_cache[key] stream = torch.cuda.current_stream(device_idx).cuda_stream - grid = grid(meta) if hasattr(grid, '__call__') else grid + csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)} + grid = grid(csts) if hasattr(grid, '__call__') else grid + if isinstance(grid, int): + grid = tuple(grid) callable(stream, params, *grid) return callable @@ -697,31 +746,31 @@ class Autotuner: def _bench(self, *args, config, **meta): # check for conflicts, i.e. 
meta-parameters both provided # as kwargs and by the autotuner - conflicts = meta.keys() & config.meta.keys() + conflicts = meta.keys() & config.kwargs.keys() if conflicts: raise ValueError( f"Conflicting meta-parameters: {', '.join(conflicts)}." " Make sure that you don't re-define auto-tuned symbols." ) # augment meta-parameters with tunable ones - current = dict(meta, **config.meta) + current = dict(meta, **config.kwargs) def kernel_call(): self.hook(args) self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) return triton.testing.do_bench(kernel_call) - def __call__(self, *args, **meta): + def __call__(self, *args, **kwargs): if len(self.configs) > 1: key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: - timings = {config: self._bench(*args, config=config, **meta) \ + timings = {config: self._bench(*args, config=config, **kwargs) \ for config in self.configs} self.cache[key] = builtins.min(timings, key=timings.get) self.hook(args) config = self.cache[key] else: config = self.configs[0] - return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta) + return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) @functools.lru_cache() @@ -769,6 +818,8 @@ class JITFunction: # when called with a grid using __getitem__ self.kernel_decorators = [] self.kernel = None + # annotations + self.__annotations__ = fn.__annotations__ # forward docs self.__doc__ = fn.__doc__ @@ -839,8 +890,8 @@ class Config: Mostly useful for matrix multiplication workloads on SM80+ GPUs. :type num_stages: int """ - def __init__(self, meta, num_warps=4, num_stages=2): - self.meta = meta + def __init__(self, kwargs, num_warps=4, num_stages=2): + self.kwargs = kwargs self.num_warps = num_warps self.num_stages = num_stages diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e584aecb1..5eed3b67f 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -14,9 +14,11 @@ def _to_ir(x, builder): return builder.get_int64(x) elif isinstance(x, float): return builder.get_float32(x) - if isinstance(x, block): + elif isinstance(x, constexpr): + return _to_ir(x.value, builder) + elif isinstance(x, block): return x.handle - if isinstance(x, dtype): + elif isinstance(x, dtype): return x.handle(builder) return x @@ -257,6 +259,86 @@ class block: return frontend.cast(self, dtype, _builder) +# ----------------------- +# constexpr +# ----------------------- + +class constexpr: + """ + This class is used to store a value that is known at compile-time. 
+ """ + + def __init__(self, value): + self.value = value + + def __add__(self, other): + return self.value + other.value + + def __radd__(self, other): + return other.value + self.value + + def __sub__(self, other): + return self.value - other.value + + def __rsub__(self, other): + return other.value - self.value + + def __mul__(self, other): + return self.value * other.value + + def __rmul__(self, other): + return other.value * self.value + + def __truediv__(self, other): + return self.value / other.value + + def __rtruediv__(self, other): + return other.value / self.value + + def __floordiv__(self, other): + return self.value // other.value + + def __rfloordiv__(self, other): + return other.value // self.value + + # + + def __gt__(self, other): + return self.value > other.value + + def __rgt__(self, other): + return other.value > self.value + + def __ge__(self, other): + return self.value >= other.value + + def __rge__(self, other): + return other.value >= self.value + + def __lt__(self, other): + return self.value < other.value + + def __rlt__(self, other): + return other.value < self.value + + def __le__(self, other): + return self.value <= other.value + + def __rle__(self, other): + return other.value <= self.value + + def __eq__(self, other): + return self.value == other.value + + def __ne__(self, other): + return self.value != other.value + + def __bool__(self): + return bool(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + # ----------------------- # SPMD Programming Model # ----------------------- @@ -312,7 +394,12 @@ def zeros(shape, dtype, _builder=None): :param dtype: Data-type of the new array, e.g., :code:`tl.float16` :type dtype: DType """ - shape = [int(x.handle) if isinstance(x, block) else x for x in shape] + for i, d in enumerate(shape): + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + shape = [x.value for x in shape] return frontend.zeros(shape, dtype, _builder) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index bcba2e505..8a020e5c2 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -1,6 +1,5 @@ import triton import triton.language as tl -import triton._C.libtriton as libtriton import torch # ******************************************************** @@ -21,54 +20,46 @@ def _sdd_kernel( stride_za, stride_ha, stride_ma, stride_ak, stride_zb, stride_hb, stride_bk, stride_nb, stride_zc, stride_hc, stride_mc, stride_nc, - K, grid_offset, lut, **meta + K, grid_offset, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + BLOCK: tl.constexpr, EVEN_K: tl.constexpr ): - TILE_M = meta['TILE_M'] - TILE_N = meta['TILE_N'] - TILE_K = meta['TILE_K'] - BLOCK = meta['BLOCK'] #------------# #- Prologue -# #------------# - pid1 = tl.program_id(1) + grid_offset - blockidm = tl.arange(0, TILE_M) // BLOCK - blockidn = tl.arange(0, TILE_N) // BLOCK - offlutm = blockidm * (TILE_N // BLOCK) * 4 - offlutn = blockidn * 4 - header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4 - # batch offset - off_z = tl.program_id(2) - # head offset - off_h = tl.load(header + 0) + block_id = tl.program_id(1) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + # initialize pointers to A 
- start_am = tl.load(header + 1 + offlutm) + start_am = tl.load(lut + 1) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_ak = tl.arange(0, TILE_K) - a_ptrs = A + off_z * stride_za \ + a_ptrs = A + (off_z * stride_za \ + off_h * stride_ha \ + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ak + + offs_ak[None, :] * stride_ak) # initialize pointers to B - start_bn = tl.load(header + 2 + offlutn) + start_bn = tl.load(lut + 2) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bk = tl.arange(0, TILE_K) - b_ptrs = B + off_z * stride_zb \ + b_ptrs = B + (off_z * stride_zb \ + off_h * stride_hb \ + offs_bn[None, :] * stride_nb \ - + offs_bk[:, None] * stride_bk + + offs_bk[:, None] * stride_bk) ## ---------------- ## ## Inner Loop ## ## ---------------- ## acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) for k in range(K, 0, -TILE_K): - if meta['EVEN_K']: + if EVEN_K: a = tl.load(a_ptrs) b = tl.load(b_ptrs) else: a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) acc += tl.dot(a, b) a_ptrs += TILE_K * stride_ak b_ptrs += TILE_K * stride_bk @@ -76,22 +67,15 @@ def _sdd_kernel( ## ---------------- ## ## Epilogue ## ## ---------------- ## - blockidm = tl.arange(0, TILE_M) // BLOCK - blockidn = tl.arange(0, TILE_N) // BLOCK - offlutm = blockidm * (TILE_N // BLOCK) * 4 - offlutn = blockidn * 4 - off_block_id = 3 + offlutm[:, None] + offlutn[None, :] - block_id = tl.load(header + off_block_id) - # initialize pointers to C offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK - pc = C + off_z * stride_zc \ + pc = C + (off_z * stride_zc \ + block_id * stride_hc \ + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc + + offs_cn[None, :] * stride_nc) tl.store(pc, c, mask=True) -def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs): +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): # (A * B)^T = B^T * A^T if trans_c: a, b = b, a @@ -102,46 +86,28 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, Ka, Kb = a.shape[a_dim], b.shape[b_dim] if Ka != Kb: raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") - if Ka % 16 != 0: - raise ValueError('Reduction size for SDD must be a multiple of 16') # allocate output - n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)]) - c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device) - # each iteration of the loop below - # computes the value for one group of super-blocks - # (e.g., all 4x4 super-blocks) - for lut, width, pack in zip(luts, widths, packs): - # maximum grid size in Triton/CUDA is 64k but we may have more - # super-blocks than that. 
- max_grid = 65535 - for off_grid in range(0, width, max_grid): - grid = [1, min(max_grid, width - off_grid), c.shape[0]] - # fmt: off - pgm = _sdd_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(2), c.stride(3), - Ka, off_grid, lut, - TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3, - num_warps=4, - ) - # print(pgm.asm['ptx']) - # exit() + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [1, c.shape[1], c.shape[0]] + _sdd_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), c.stride(2), c.stride(3), + Ka, 0, lut, + TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4, + num_warps=4, + ) return c + def sdd_lut(layout, block, device): - start_width = 128 // block - layout = layout.type(torch.int32) - superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width) - luts, widths, packs = [], [], [] - for size, nnz in superblocks: - nnz = nnz.reshape(-1, 4) - width = nnz.shape[0] // (size * size) - luts.append(torch.from_numpy(nnz).type(torch.int32).to(device)) - widths.append(width) - packs.append(size) - return luts, None, widths, packs + lut = layout.nonzero(as_tuple=False).to(device).int() + return lut, None # ----------------------------- # Dense = Sparse x Dense (DSD) @@ -154,12 +120,10 @@ def _dsd_kernel( stride_az, stride_ha, stride_am, stride_ak, stride_zb, stride_hb, stride_bk, stride_bn, stride_zc, stride_hc, stride_cm, stride_cn, - DS0, DS1, lut, **meta + DS0, DS1, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr ): - TILE_M = meta['TILE_M'] - TILE_N = meta['TILE_N'] - TILE_K = meta['TILE_K'] - GROUP_SIZE_M = meta['GROUP_SIZE_M'] #------------# #- Prologue -# #------------# @@ -167,9 +131,9 @@ def _dsd_kernel( pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) pidz = tl.program_id(2) - header = lut + pid_m * 4 + header = lut + pid_n * 4 offset = tl.load(header + 0) K = tl.load(header + 1) column = tl.load(header + 2) @@ -185,7 +149,8 @@ def _dsd_kernel( + offs_am[:, None] * stride_am \ + offs_ak[None, :] * stride_ak # initialize pointers to B (dense) - offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N) + offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) start_bk = tl.load(pinc) start_bk = tl.multiple_of(start_bk, 8) # compiler hint offs_bk = start_bk + tl.arange(0, TILE_K) @@ -197,28 +162,33 @@ def _dsd_kernel( ## Inner Loop ## ## ---------------- ## acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) for k in range(K, 0, -TILE_K): a = tl.load(pa, mask=True) b = tl.load(pb, mask=offs_bn[None, :] < 
DS0) acc += tl.dot(a, b) + pa += inc_a + pb += inc_b*stride_bk pinc += 2 inc_a = tl.load(pinc + 1) inc_a = tl.multiple_of(inc_a, 8) inc_b = tl.load(pinc) inc_b = tl.multiple_of(inc_b, 8) - pa += inc_a - pb += inc_b*stride_bk c = acc.to(C.dtype.element_ty) # initialize pointers to C offs_cm = column*TILE_M + tl.arange(0, TILE_M) - offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N) + offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N) pc = C + off_h * stride_hc \ + pidz * stride_zc \ + offs_cm[:, None] * stride_cm \ + offs_cn[None, :] * stride_cn tl.store(pc, c, mask = offs_cn[None, :] < DS0) -def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): # shapes / dtypes AS1 = block * spdims[2 if trans_a else 1] BS0 = b.size(0) @@ -230,11 +200,15 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w CS1 = BS1 CS2 = BS3 if trans_c else AS1 CS3 = AS1 if trans_c else BS3 - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out # meta-parameter heuristics - TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block] + TILE_N = 128 # compute output - grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0] + grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0] # fmt: off _dsd_kernel[grid]( a, b, c, @@ -242,8 +216,8 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), BS3, AS1, lut, - TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3, - num_warps=4, GROUP_SIZE_M=8, + TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + num_warps=4, GROUP_SIZE_M=4, ) # exit() return c @@ -323,7 +297,7 @@ def dsd_lut(layout, block, step, trans, device): lut = torch.cat((header, incs)) lut = lut.type(torch.int32).to(device) # create locks - return lut, None, width, None + return lut, width # ----------------------------- # Dense = Dense x Sparse (DDS) @@ -334,12 +308,10 @@ def _dds_kernel( stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_bk, stride_bn, stride_zc, stride_hc, stride_mc, stride_nc, - DS0, DS1, lut, **meta + DS0, DS1, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, ): - TILE_M = meta['TILE_M'] - TILE_N = meta['TILE_N'] - TILE_K = meta['TILE_K'] - GROUP_SIZE_M = meta['GROUP_SIZE_M'] #------------# #- Prologue -# #------------# @@ -347,16 +319,17 @@ def _dds_kernel( pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) pid_z = tl.program_id(2) - header = lut + pid_m * 4 + header = lut + pid_n * 4 offset = tl.load(header + 0) AS1 = tl.load(header + 1) column = tl.load(header + 2) off_h = tl.load(header + 3) pinc = lut + offset # initialize pointers to A (dense) - offs_am = pid_n*TILE_M + tl.arange(0, TILE_M) + offs_am = pid_m*TILE_M + tl.arange(0, TILE_M) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M) 
start_ak = tl.load(pinc) start_ak = tl.multiple_of(start_ak, 8) offs_ak = start_ak + tl.arange(0, TILE_K) @@ -394,7 +367,7 @@ def _dds_kernel( ## ---------------- ## c = acc.to(C.dtype.element_ty) # initialize pointers to C (dense) - offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M) + offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M) offs_cn = column * TILE_N + tl.arange(0, TILE_N) ptrs_c = C + off_h * stride_hc \ + pid_z * stride_zc \ @@ -403,7 +376,7 @@ def _dds_kernel( # write back tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) -def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): # shapes / dtypes AS0 = a.size(0) AS1 = a.size(1) @@ -415,9 +388,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w CS1 = AS1 CS2 = BS2 if trans_c else AS2 CS3 = AS2 if trans_c else BS2 - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] - grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0] + grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0] # fmt: off _dds_kernel[grid]( a, b, c, @@ -425,8 +402,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), AS2, BS2, lut, - TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3, - num_warps=4, GROUP_SIZE_M=8, + TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + num_warps=4, GROUP_SIZE_M=4, ) return c @@ -439,25 +416,23 @@ class _matmul(torch.autograd.Function): @staticmethod def forward( - ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks, - da_width, da_packs, db_lut, db_num_locks, db_width, db_packs + ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, + c_lut, c_width, da_lut, da_width, db_lut, db_width, out ): - c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs) + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) # save for backward ctx.save_for_backward(a, b) - ctx.da_num_locks = da_num_locks ctx.da_lut = da_lut ctx.da_width = da_width - ctx.da_packs = da_packs ctx.db_lut = db_lut - ctx.db_num_locks = db_num_locks ctx.db_width = db_width - ctx.db_packs = db_packs ctx.mode = mode ctx.spdims = spdims ctx.block = block ctx.trans_a = trans_a ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None return c @staticmethod @@ -466,155 +441,55 @@ class _matmul(torch.autograd.Function): a, b = ctx.saved_tensors da, db = None, None mode = ctx.mode - # gradients w.r.t. a if ctx.needs_input_grad[0]: mode_da = mode[1] + mode[0] + mode[2] da = _matmul.fn[mode_da]( - dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width, - ctx.da_packs + dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width, ) # gradients w.r.t. 
b if ctx.needs_input_grad[1]: mode_db = mode[2] + mode[1] + mode[0] db = _matmul.fn[mode_db]( - a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width, - ctx.db_packs + a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width, ) + dout = dc if ctx.has_out else None return da, db, None, None, None,\ None, None, None, None,\ - None, None, None, None, None, None,\ - None, None, None, None, None, None,\ - None, None, None, None, None, None - + None, None, None, None, None, dout class matmul: - def make_lut(self, dtype, device): - key = (dtype, device) - if key in self.lut_cache: - return self.lut_cache[key] - # C look-up table - layout, block = self.layout, self.block - step = min(block, 32) - if self.mode == 'sdd': - c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device) - elif self.mode == 'dsd': - c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device) - elif self.mode == 'dds': - c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device) - # DA look-up table - if self.mode == 'sdd': - da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device) - elif self.mode == 'dsd': - da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device) - elif self.mode == 'dds': - da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device) - # DB look-up table - if self.mode == 'sdd': - db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device) - elif self.mode == 'dsd': - db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device) - elif self.mode == 'dds': - db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device) - self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs, - da_lut, da_num_locks, da_width, da_packs, - db_lut, db_num_locks, db_width, db_packs) - return self.lut_cache[key] - def __init__(self, layout, block, mode, trans_a=False, trans_b=False): + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): if mode not in ['sdd', 'dsd', 'dds']: raise NotImplementedError('Supported modes are: sdd, dsd, dds') - # look-up table cache - self.lut_cache = dict() - # attributes self.block = block self.mode = mode self.trans_a = trans_a self.trans_b = trans_b - - layout_dim = layout.ndim - assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s" - - if not mode == 'sdd': - # Dims to be reduced on the 'inside' of the matmul, either -1 or -2 - trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2) - self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner - sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1) - - # Inner dim of the dense input should be equal to the inner dim of the sparse input - self.dense_inner_size = layout.shape[sparse_inner] * block - # Expected shape for sparse inputs - self.sparse_shape = (layout.sum().item(), block, block) - - # Support using the same layout across attention heads etc. 
- if layout_dim == 2: - layout = layout.unsqueeze(0) - - layout = layout.long() # Above code assumes the layout tensor is an integral type + self.trans_c = trans_c self.layout = layout self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) - def __call__(self, a, b): - c_lut, c_num_locks, c_width, c_packs,\ - da_lut, da_num_locks, da_width, da_packs,\ - db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device) - - # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior - # and potential illegal memory accesses - original_dims = max(a.ndim, b.ndim) - a, b = self._validate_inputs(a, b) - - # execute + def __call__(self, a, b, out = None): c = _matmul.apply( - a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, - c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs + a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, + self.c_lut, self.c_width, + self.da_lut, self.da_width, + self.db_lut, self.db_width, + out ) - # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input - dims_to_trim = c.ndim - original_dims - for _ in range(dims_to_trim): - c = c.squeeze(0) - - return c - - def _validate_inputs(self, a, b): - if a.device != b.device: - raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A " - f"and {b.device} for tensor B") - if not a.is_cuda: - raise ValueError("Only GPU devices are supported for now") - - # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here - if torch.is_autocast_enabled(): - a, b = a.half(), b.half() - elif a.dtype != b.dtype: - raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B") - - mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b - if mode != 'sdd': - # One input is sparse - dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A') - dense_inner = dense.shape[self.dense_inner_dim] - if dense_inner != self.dense_inner_size: - raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim " - f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.") - - if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape: - raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument " - f"{sparse_name}, got {sparse.shape}") - - def add_extra_dims(x): - # Add extra leading singleton dimensions if needed - dims_needed = 4 - x.ndim - if dims_needed > 0: - singletons = [1] * dims_needed - x = x.view(*singletons, *x.shape) - elif dims_needed < 0: - raise ValueError("Tensors with more than 4 dimensions are 
not currently supported") - - return x - - # Pad shapes with leading singleton dimensions - a = add_extra_dims(a) - b = add_extra_dims(b) - - return a, b + return c \ No newline at end of file diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 10f806af2..5b9d752ec 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -16,10 +16,9 @@ def num_warps(n): @triton.jit def _forward( X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - **meta + TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr, + KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr, ): - TN = meta['TN'] - BLOCK = meta['BLOCK'] pidhm = tl.program_id(0) pidz = tl.program_id(1) # create index ranges @@ -43,25 +42,25 @@ def _forward( x = tl.load(px, mask=check, other=-float('inf')) x = x.to(tl.float32) # apply scale - if meta['APPLY_SCALE']: + if APPLY_SCALE: x = x * scale # apply RPE - if meta['APPLY_RPE']: + if APPLY_RPE: prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn rpe = tl.load(prpe, mask=check, other=0) x = x + rpe # apply key-padding mask - if meta['APPLY_KP_MASK']: + if APPLY_KP_MASK: pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn kp_m = tl.load(pkp_m, mask=check, other=-float('inf')) - if meta['KP_MASK_MUL']: + if KP_MASK_MUL: kp_m = tl.where(kp_m == 0, -float('inf'), 0.) x = x + kp_m # apply attention mask - if meta['APPLY_ATTN_MASK']: + if APPLY_ATTN_MASK: pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn attn_m = tl.load(pattn_m, mask=check, other=-float('inf')) - if meta['ATTN_MASK_MUL']: + if ATTN_MASK_MUL: attn_m = tl.where(attn_m == 0, -float('inf'), 0.) 
x = x + attn_m # apply causal mask @@ -75,11 +74,9 @@ def _forward( @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) @triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']}) @triton.jit -def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): +def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr): pidhm = tl.program_id(0) pidz = tl.program_id(1) - TN = meta['TN'] - BLOCK = meta['BLOCK'] # create index ranges rxm = pidhm % BLOCK rbm = pidhm // BLOCK @@ -172,8 +169,7 @@ class _softmax(torch.autograd.Function): APPLY_KP_MASK = apply_kp_mask, APPLY_ATTN_MASK = apply_attn_mask, KP_MASK_MUL = (kp_mask_mode == 'mul'), - ATTN_MASK_MUL = (attn_mask_mode == 'mul'), - force_nc_cache = True) + ATTN_MASK_MUL = (attn_mask_mode == 'mul')) # save to context ctx.mark_dirty(x) ctx.save_for_backward(x, lut) @@ -196,7 +192,7 @@ class _softmax(torch.autograd.Function): # run kernel M = x.shape[0] grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] - _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block) + _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block) return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 87833e3c0..8711d5b19 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -26,8 +26,7 @@ def num_warps(N): @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])}) @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])}) @triton.jit -def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): - BLOCK = meta['BLOCK'] +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): row = tl.program_id(0) cols = tl.arange(0, BLOCK) idx = tl.load(IDX + row) @@ -52,8 +51,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])}) @triton.jit -def _backward(PROBS, IDX, DPROBS, N, **meta): - BLOCK = meta['BLOCK'] +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): row = tl.program_id(0) cols = tl.arange(0, BLOCK) idx = tl.load(IDX + row) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 22f5f6cc2..802908657 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -26,13 +26,9 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - LOCKS, **META): - # extract meta-parameters - BLOCK_M = META['BLOCK_M'] - BLOCK_N = META['BLOCK_N'] - BLOCK_K = META['BLOCK_K'] - GROUP_M = META['GROUP_M'] - SPLIT_K = META['SPLIT_K'] + LOCKS, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -55,7 +51,7 @@ def _kernel(A, B, C, M, N, K, B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(K, 0, -BLOCK_K*SPLIT_K): - if META['EVEN_K']: + if EVEN_K: a = tl.load(A) b = tl.load(B) else: @@ -113,14 +109,11 @@ class _matmul(torch.autograd.Function): locks = _matmul._locks[device] # launch kernel 
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, - M, N, K, + _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - locks, - GROUP_M=8) - # done + locks, GROUP_M=8) return c @staticmethod diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 4446cf6e9..0934c8ea1 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn # Compute Kernel # -------------------------- +from triton.language.core import constexpr import torch import triton import triton.language as tl @@ -23,9 +24,9 @@ def add_kernel( y_ptr, # *Pointer* to second input vector output_ptr, # *Pointer* to output vector n_elements, # Size of the vector - **meta, # Optional meta-parameters for the kernel + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process + # NOTE: `constexpr` so it can be used as a shape value ): - BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process # There are multiple 'program's processing different data. We identify which program # we are here pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 @@ -37,8 +38,8 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) # Create a mask to guard memory operations against out-of-bounds accesses mask = offsets < n_elements - # Load x and y from DRAM, masking out any extar elements in case the input is not a - # multiple of the block size + # Load x and y from DRAM, masking out any extra elements in case + # the input is not a multiple of the block size x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) output = x + y diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 15efa7c81..2c0cfb9a8 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -65,11 +65,11 @@ import triton.language as tl @triton.jit def softmax_kernel( - output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta + output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, + BLOCK_SIZE: tl.constexpr ): # The rows of the softmax are independent, so we parallelize across those row_idx = tl.program_id(0) - BLOCK_SIZE = meta['BLOCK_SIZE'] # The stride represents how much we need to increase the pointer to advance 1 row row_start_ptr = input_ptr + row_idx * input_row_stride # The block size is the next power of two greater than n_cols, so we can fit each diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 1d9fea638..2d2ab91e9 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -182,17 +182,13 @@ def matmul_kernel( stride_bk, stride_bn, stride_cm, stride_cn, # Meta-parameters - **meta, -): + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, + ): """Kernel for computing the matmul C = A x B. 
A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ - # extract meta-parameters - BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] - BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] - BLOCK_SIZE_K = meta['BLOCK_SIZE_K'] - GROUP_SIZE_M = 8 - # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse
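For reference, a minimal sketch of the calling convention this patch introduces, adapted from the test_performance.py hunk earlier in this patch and assuming the patched Triton build: meta-parameters are declared as tl.constexpr annotations on the kernel signature and passed by keyword at launch, instead of being unpacked from **meta inside the kernel body. The vector size N and the data below are placeholder values for illustration only.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _add(x_ptr, y_ptr, output_ptr, n_elements,
             BLOCK_SIZE: tl.constexpr):  # compile-time constant, usable as a shape
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, x + y, mask=mask)

    N = 98432  # hypothetical problem size
    x = torch.randn(N, device='cuda')
    y = torch.randn(N, device='cuda')
    z = torch.empty_like(x)
    # the grid callable now receives the launch kwargs, including constexpr values
    grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']),)
    _add[grid](x, y, z, N, BLOCK_SIZE=1024)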