From 2d6df9b518a8152f777eb79b6b0a84becb706353 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 29 Oct 2021 01:24:19 -0700 Subject: [PATCH 001/215] [PACKAGING] Bumped dev version to 1.1.2 --- .github/workflows/integration-tests.yml | 2 +- python/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ac0fe14f0..987b346a3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -5,7 +5,7 @@ on: pull_request: branches: - master - - v2 + - v2.0 jobs: diff --git a/python/setup.py b/python/setup.py index 0d962355f..f77b92666 100644 --- a/python/setup.py +++ b/python/setup.py @@ -121,7 +121,7 @@ class CMakeBuild(build_ext): setup( name="triton", - version="1.1.1", + version="1.1.2", author="Philippe Tillet", author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", From 770ea96ccad02dac0d8805f60431bf58185a7877 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 29 Oct 2021 01:28:17 -0700 Subject: [PATCH 002/215] [PACKAGING] Bumped dev version to 2.0.0 --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index f77b92666..21cabe182 100644 --- a/python/setup.py +++ b/python/setup.py @@ -121,7 +121,7 @@ class CMakeBuild(build_ext): setup( name="triton", - version="1.1.2", + version="2.0.0", author="Philippe Tillet", author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", From 2acaa4d0dd61b4936a62327a144af536715bf96a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 30 Oct 2021 00:32:58 -0700 Subject: [PATCH 003/215] [LANG] Added support for constexpr (#361) --- lib/codegen/analysis/layout.cc | 3 - lib/codegen/transform/coalesce.cc | 3 +- lib/driver/llvm.cc | 2 +- python/test/regression/test_performance.py | 6 +- python/test/unit/language/test_core.py | 73 ++-- .../test/unit/operators/test_blocksparse.py | 6 +- python/test/unit/operators/test_matmul.py | 4 +- python/triton/code_gen.py | 107 ++++-- python/triton/language/core.py | 93 ++++- python/triton/ops/blocksparse/matmul.py | 347 ++++++------------ python/triton/ops/blocksparse/softmax.py | 26 +- python/triton/ops/cross_entropy.py | 6 +- python/triton/ops/matmul.py | 19 +- python/tutorials/01-vector-add.py | 9 +- python/tutorials/02-fused-softmax.py | 4 +- python/tutorials/03-matrix-multiplication.py | 12 +- 16 files changed, 355 insertions(+), 365 deletions(-) diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 6ea0dd219..64163c91c 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -198,8 +198,6 @@ scanline_layout::scanline_layout(size_t num_warps, bool is_dot = std::any_of(values.begin(), values.end(), [&](ir::value* v) { return dynamic_cast(v); }); - - std::vector ptrs; for(ir::value *v: values) for(ir::user *usr: v->get_users()) @@ -215,7 +213,6 @@ scanline_layout::scanline_layout(size_t num_warps, contiguous = std::max(contiguous, std::min(align->get(ptr, i), 128 / nbits)); } - nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i])); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); size /= shape_[i]; diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 448517408..ae8ce034d 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -77,7 +77,6 @@ void 
coalesce::run(ir::module &mod) { builder.insert(new_x); x->replace_all_uses_with(new_x); new_x->replace_uses_of_with(new_x, x); -// new_x->replace_uses_of_with(new_x, new_x); } } for(ir::function *fn: mod.get_function_list()) @@ -101,6 +100,8 @@ void coalesce::run(ir::module &mod) { ir::instruction* curr = queue.back(); seen.insert(curr); queue.pop_back(); + if(auto dot_inst = dynamic_cast(curr)) + break; if(auto io_inst = dynamic_cast(curr)){ in_contig = align_->contiguous(io_inst->get_pointer_operand()); break; diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 3c11fbf35..f3c76ce77 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { ofs.close(); std::string cmd; int err; - cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o"; + cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; err = system(cmd.c_str()); CUmodule ret; std::ifstream _cubin(_fbin, std::ios::binary ); diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index e205828d6..215003447 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -67,8 +67,8 @@ def test_matmul(M, N, K): import triton.language as tl @triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta): - BLOCK_SIZE = meta['BLOCK_SIZE'] +def _add(x_ptr, y_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -99,7 +99,7 @@ def test_elementwise(N): z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) y = torch.randn_like(z) - grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), ) + grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250) cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6 diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 25054f0dc..9354ec233 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -40,7 +40,7 @@ def patch_kernel(template, to_replace): def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 @triton.jit - def kernel(X, **meta): + def kernel(X, SIZE: tl.constexpr): pass x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) kernel[(1, )](x, SIZE=SIZE, num_warps=4) @@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit - def kernel(Z, X, **meta): - off = tl.arange(0, meta['SIZE']) + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) x = tl.load(X + off) z = GENERATE_TEST_HERE tl.store(Z + off, z) @@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c SIZE = 128 # define the kernel / launch-grid @triton.jit - def kernel(Z, X, Y, **meta): - off = tl.arange(0, meta['SIZE']) + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) x = tl.load(X + off) y = tl.load(Y + off) z = GENERATE_TEST_HERE @@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'): # Triton kernel @triton.jit - def kernel(Z, X, **meta): - SIZE = meta['SIZE'] + def kernel(Z, X, SIZE: tl.constexpr): m = tl.arange(0, SIZE) n = tl.arange(0, SIZE) x = 
tl.load(X_PTR_EXPR) @@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): # triton kernel @triton.jit - def kernel(X, Z, **meta): + def kernel(X, Z): pid = tl.program_id(0) x = tl.load(X + pid) old = GENERATE_TEST_HERE @@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # triton kernel @triton.jit - def kernel(X, Z, **meta): + def kernel(X, Z, BITCAST: tl.constexpr): x = tl.load(X) - z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST']) + z = x.to(Z.dtype.element_ty, bitcast = BITCAST) tl.store(Z, z) # triton result @@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'): # triton kernel @triton.jit - def kernel(X, Z, **meta): - x = tl.load(X + tl.arange(0, meta['BLOCK'])) + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) tl.store(Z, tl.sum(x, axis=0)) x = triton.testing.random((shape,), dtype=dtype, device=device) @@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'): dtype = cvt[dtype] # triton kernel @triton.jit - def kernel(X, Z, **meta): - range_m = tl.arange(0, meta['BLOCK_M']) - range_n = tl.arange(0, meta['BLOCK_N']) - x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :]) - z = tl.sum(x, axis=meta['AXIS']) + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :]) + z = tl.sum(x, axis=AXIS) tl.store(Z + range_m, z) # input x = triton.testing.random(shape, dtype=dtype, device=device) @@ -429,9 +428,8 @@ def test_permute(dtype, shape, perm, device='cuda'): # triton kernel @triton.jit def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, **meta): - BLOCK_M = meta['BLOCK_M'] - BLOCK_N = meta['BLOCK_N'] + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn @@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'): @triton.jit def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, - Z, stride_zm, stride_zn, **meta): - BLOCK_M = meta['BLOCK_M'] - BLOCK_K = meta['BLOCK_K'] - BLOCK_N = meta['BLOCK_N'] + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) off_k = tl.arange(0, BLOCK_K) @@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'): Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn z = tl.dot(tl.load(Xs), tl.load(Ys)) - if meta['ADD_MATRIX']: + if ADD_MATRIX: z += tl.load(Zs) - if meta['ADD_ROWS']: + if ADD_ROWS: ZRs = Z + off_m * stride_zm z += tl.load(ZRs)[:, None] - if meta['ADD_COLS']: + if ADD_COLS: ZCs = Z + off_n * stride_zn z += tl.load(ZCs)[None, :] tl.store(Zs, z) @@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'): def test_dot_without_load(): @triton.jit - def kernel(out, **meta): + def kernel(out): pid = tl.program_id(axis=0) a = tl.zeros((32, 32), tl.float32) b = tl.zeros((32, 32), tl.float32) @@ -538,9 +535,10 @@ def test_arange(start, device='cuda'): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @triton.jit - def _kernel(z, **meta): - off = tl.arange(0, meta['BLOCK']) - val = tl.arange(meta['START'], meta['END']) + def _kernel(z, BLOCK: 
tl.constexpr, + START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) tl.store(z + off, val) _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK) z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device) @@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'): @triton.jit def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, - in_numel, in2_numel, out_numel, **meta): - M = meta['M'] - N = meta['N'] - K = meta['K'] + in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): M_offsets = tl.arange(0, M) N_offsets = tl.arange(0, N) @@ -605,14 +601,13 @@ def test_load_cache_modifier(cache): dst = torch.empty(128, device='cuda') @triton.jit - def _kernel(dst, src, **meta): + def _kernel(dst, src, CACHE: tl.constexpr): offsets = tl.arange(0, 128) - x = tl.load(src+offsets, cache_modifier=meta['CACHE']) + x = tl.load(src+offsets, cache_modifier=CACHE) tl.store(dst+offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] - if cache == '': assert 'ld.global.ca' not in ptx assert 'ld.global.cg' not in ptx @@ -644,7 +639,7 @@ def test_load_cache_modifier(cache): #---------------- def test_noop(device='cuda'): @triton.jit - def kernel(**meta): + def kernel(x): pass x = triton.testing.random((1,), dtype=torch.int32, device=device) kernel[(1, )](x) \ No newline at end of file diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 3b9a1c17f..b9cdc23c7 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= }[MODE] layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) # triton result - op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest) @@ -151,8 +151,8 @@ def triton_attention( value: torch.Tensor, scale: float, ): - sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True) - sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False) + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) sparse_softmax = triton.ops.blocksparse.softmax( layout, block, diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index dbf1974ce..0751d044d 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -66,8 +66,8 @@ import torch def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually - META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} - configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)] + kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 
'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)] kernel = triton.ops._matmul.kernel decorators = kernel.kernel_decorators kernel.kernel_decorators = [] diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8622333bf..f14f3b135 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor): arg_values = [] for i, arg_name in enumerate(arg_names): if i in self.constants: - cst = triton.language.core._to_ir(self.constants[i], self.builder) + cst = self.constants[i] + if not isinstance(cst, triton.language.constexpr): + cst = triton.language.constexpr(self.constants[i]) arg_values.append(cst) else: if i in self.attributes: @@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor): fn.add_attr(i + 1, attr) fn.args[i].name = arg_name arg_values.append(fn.args[i]) + for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) if inline: @@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor): ast.NodeVisitor.generic_visit(self, node) return node.arg + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == triton.language.constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + self.lscope[target] = triton.language.constexpr(value) + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): _names = [] for target in node.targets: @@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + if isinstance(value, triton.language.constexpr): + value = value.value if not isinstance(value, triton.language.block): value = triton.language.core._to_ir(value, self.builder) self.set_value(name, value) @@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor): def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) + if isinstance(lhs, triton.language.core.constexpr): + lhs = lhs.value + if isinstance(rhs, triton.language.core.constexpr): + rhs = rhs.value fn = { ast.Add: '__add__', ast.Sub: '__sub__', @@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - kws = dict() - if self.is_triton_object(lhs): - kws['_builder'] = self.builder - ret = getattr(lhs, fn)(rhs, **kws) - if ret is NotImplemented: - if self.is_triton_object(rhs): - kws['_builder'] = self.builder + return getattr(lhs, fn)(rhs, _builder=self.builder) + elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] - ret = getattr(rhs, fn)(lhs, **kws) - return ret + return getattr(rhs, fn)(lhs, _builder=self.builder) + else: + return getattr(lhs, fn)(rhs) def visit_If(self, node): cond = self.visit(node.test) @@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) + if isinstance(lhs, triton.language.core.constexpr): + lhs = lhs.value + if isinstance(rhs, triton.language.core.constexpr): + rhs = rhs.value fn = { ast.Eq: '__eq__', ast.NotEq: '__ne__', @@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor): def visit_UnaryOp(self, node): op = 
self.visit(node.operand) + if isinstance(op, triton.language.core.constexpr): + op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor): return fn(*args, **kws) def visit_Num(self, node): - return node.n + return triton.language.constexpr(node.n) def visit_Attribute(self, node): lhs = self.visit(node.value) @@ -477,6 +505,8 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] + if isinstance(obj, triton.language.core.constexpr): + obj = obj.value if isinstance(obj, int): if abs(obj) <= 0xffffffff: return 'I' @@ -485,6 +515,8 @@ class Kernel: return 'f' if isinstance(obj, bool): return 'B' + if isinstance(obj, str): + return 'str' assert False @@ -537,7 +569,8 @@ class Kernel: def __init__(self, fn): self.fn = fn - def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta): + def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): + wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel @@ -547,7 +580,7 @@ class Kernel: # generate Triton-IR # export symbols visible from self.fn into code-generator object gscope = sys.modules[self.fn.module].__dict__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta) + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) try: generator.visit(self.fn.parse()) except Exception as e: @@ -566,7 +599,19 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta): + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + # handle arguments passed by name + kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} + wargs = list(wargs) + for i, pos in enumerate(sorted(kwargs)): + wargs.insert(pos + i, kwargs[pos]) + if len(wargs) != len(self.fn.arg_names): + raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") + # handle annotations + for name, type in self.fn.__annotations__.items(): + pos = self.fn.arg_names.index(name) + assert type == triton.language.core.constexpr + wargs[pos] = type(wargs[pos]) # device inference tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] if len(tensor_idxs) == 0: @@ -601,18 +646,19 @@ class Kernel: args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ if isinstance(a, int) and i not in self.fn.do_not_specialize} + # transforms ints whose value is one into constants for just-in-time compilation constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + # compute hash for caching this kernel types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) attr_key = tuple(attributes.items()) - meta_key = tuple(sorted(meta.items())) const_key = tuple(constants.items()) compute_capability = torch.cuda.get_device_capability(device) - key = ( self.fn.cache_key, version_key(), compute_capability, - types_key, attr_key, num_warps, num_stages, 
meta_key, const_key + types_key, attr_key, num_warps, num_stages, const_key ) key = repr(key) @@ -644,7 +690,7 @@ class Kernel: binary = self._compile( *wargs, device=device_idx, attributes=attributes, num_warps=num_warps, num_stages=num_stages, - constants=constants, **meta + constants=constants, ) if bin_cache_path: assert bin_lock_path is not None @@ -657,12 +703,15 @@ class Kernel: drv_cache[key] = LoadedBinary(device_idx, binary) # pack arguments - fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)]) - params = struct.pack(fmt, *args) + fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)]) + params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)]) # enqueue cached function into stream callable = drv_cache[key] stream = torch.cuda.current_stream(device_idx).cuda_stream - grid = grid(meta) if hasattr(grid, '__call__') else grid + csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)} + grid = grid(csts) if hasattr(grid, '__call__') else grid + if isinstance(grid, int): + grid = tuple(grid) callable(stream, params, *grid) return callable @@ -697,31 +746,31 @@ class Autotuner: def _bench(self, *args, config, **meta): # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner - conflicts = meta.keys() & config.meta.keys() + conflicts = meta.keys() & config.kwargs.keys() if conflicts: raise ValueError( f"Conflicting meta-parameters: {', '.join(conflicts)}." " Make sure that you don't re-define auto-tuned symbols." ) # augment meta-parameters with tunable ones - current = dict(meta, **config.meta) + current = dict(meta, **config.kwargs) def kernel_call(): self.hook(args) self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) return triton.testing.do_bench(kernel_call) - def __call__(self, *args, **meta): + def __call__(self, *args, **kwargs): if len(self.configs) > 1: key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: - timings = {config: self._bench(*args, config=config, **meta) \ + timings = {config: self._bench(*args, config=config, **kwargs) \ for config in self.configs} self.cache[key] = builtins.min(timings, key=timings.get) self.hook(args) config = self.cache[key] else: config = self.configs[0] - return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta) + return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) @functools.lru_cache() @@ -769,6 +818,8 @@ class JITFunction: # when called with a grid using __getitem__ self.kernel_decorators = [] self.kernel = None + # annotations + self.__annotations__ = fn.__annotations__ # forward docs self.__doc__ = fn.__doc__ @@ -839,8 +890,8 @@ class Config: Mostly useful for matrix multiplication workloads on SM80+ GPUs. 
:type num_stages: int """ - def __init__(self, meta, num_warps=4, num_stages=2): - self.meta = meta + def __init__(self, kwargs, num_warps=4, num_stages=2): + self.kwargs = kwargs self.num_warps = num_warps self.num_stages = num_stages diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e584aecb1..5eed3b67f 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -14,9 +14,11 @@ def _to_ir(x, builder): return builder.get_int64(x) elif isinstance(x, float): return builder.get_float32(x) - if isinstance(x, block): + elif isinstance(x, constexpr): + return _to_ir(x.value, builder) + elif isinstance(x, block): return x.handle - if isinstance(x, dtype): + elif isinstance(x, dtype): return x.handle(builder) return x @@ -257,6 +259,86 @@ class block: return frontend.cast(self, dtype, _builder) +# ----------------------- +# constexpr +# ----------------------- + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + self.value = value + + def __add__(self, other): + return self.value + other.value + + def __radd__(self, other): + return other.value + self.value + + def __sub__(self, other): + return self.value - other.value + + def __rsub__(self, other): + return other.value - self.value + + def __mul__(self, other): + return self.value * other.value + + def __rmul__(self, other): + return other.value * self.value + + def __truediv__(self, other): + return self.value / other.value + + def __rtruediv__(self, other): + return other.value / self.value + + def __floordiv__(self, other): + return self.value // other.value + + def __rfloordiv__(self, other): + return other.value // self.value + + # + + def __gt__(self, other): + return self.value > other.value + + def __rgt__(self, other): + return other.value > self.value + + def __ge__(self, other): + return self.value >= other.value + + def __rge__(self, other): + return other.value >= self.value + + def __lt__(self, other): + return self.value < other.value + + def __rlt__(self, other): + return other.value < self.value + + def __le__(self, other): + return self.value <= other.value + + def __rle__(self, other): + return other.value <= self.value + + def __eq__(self, other): + return self.value == other.value + + def __ne__(self, other): + return self.value != other.value + + def __bool__(self): + return bool(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + # ----------------------- # SPMD Programming Model # ----------------------- @@ -312,7 +394,12 @@ def zeros(shape, dtype, _builder=None): :param dtype: Data-type of the new array, e.g., :code:`tl.float16` :type dtype: DType """ - shape = [int(x.handle) if isinstance(x, block) else x for x in shape] + for i, d in enumerate(shape): + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + shape = [x.value for x in shape] return frontend.zeros(shape, dtype, _builder) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index bcba2e505..8a020e5c2 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -1,6 +1,5 @@ import triton import triton.language as tl -import triton._C.libtriton as libtriton import torch # 
******************************************************** @@ -21,54 +20,46 @@ def _sdd_kernel( stride_za, stride_ha, stride_ma, stride_ak, stride_zb, stride_hb, stride_bk, stride_nb, stride_zc, stride_hc, stride_mc, stride_nc, - K, grid_offset, lut, **meta + K, grid_offset, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + BLOCK: tl.constexpr, EVEN_K: tl.constexpr ): - TILE_M = meta['TILE_M'] - TILE_N = meta['TILE_N'] - TILE_K = meta['TILE_K'] - BLOCK = meta['BLOCK'] #------------# #- Prologue -# #------------# - pid1 = tl.program_id(1) + grid_offset - blockidm = tl.arange(0, TILE_M) // BLOCK - blockidn = tl.arange(0, TILE_N) // BLOCK - offlutm = blockidm * (TILE_N // BLOCK) * 4 - offlutn = blockidn * 4 - header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4 - # batch offset - off_z = tl.program_id(2) - # head offset - off_h = tl.load(header + 0) + block_id = tl.program_id(1) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + # initialize pointers to A - start_am = tl.load(header + 1 + offlutm) + start_am = tl.load(lut + 1) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_ak = tl.arange(0, TILE_K) - a_ptrs = A + off_z * stride_za \ + a_ptrs = A + (off_z * stride_za \ + off_h * stride_ha \ + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ak + + offs_ak[None, :] * stride_ak) # initialize pointers to B - start_bn = tl.load(header + 2 + offlutn) + start_bn = tl.load(lut + 2) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bk = tl.arange(0, TILE_K) - b_ptrs = B + off_z * stride_zb \ + b_ptrs = B + (off_z * stride_zb \ + off_h * stride_hb \ + offs_bn[None, :] * stride_nb \ - + offs_bk[:, None] * stride_bk + + offs_bk[:, None] * stride_bk) ## ---------------- ## ## Inner Loop ## ## ---------------- ## acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) for k in range(K, 0, -TILE_K): - if meta['EVEN_K']: + if EVEN_K: a = tl.load(a_ptrs) b = tl.load(b_ptrs) else: a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) 
- a = tl.load(a_ptrs) - b = tl.load(b_ptrs) acc += tl.dot(a, b) a_ptrs += TILE_K * stride_ak b_ptrs += TILE_K * stride_bk @@ -76,22 +67,15 @@ def _sdd_kernel( ## ---------------- ## ## Epilogue ## ## ---------------- ## - blockidm = tl.arange(0, TILE_M) // BLOCK - blockidn = tl.arange(0, TILE_N) // BLOCK - offlutm = blockidm * (TILE_N // BLOCK) * 4 - offlutn = blockidn * 4 - off_block_id = 3 + offlutm[:, None] + offlutn[None, :] - block_id = tl.load(header + off_block_id) - # initialize pointers to C offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK - pc = C + off_z * stride_zc \ + pc = C + (off_z * stride_zc \ + block_id * stride_hc \ + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc + + offs_cn[None, :] * stride_nc) tl.store(pc, c, mask=True) -def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs): +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): # (A * B)^T = B^T * A^T if trans_c: a, b = b, a @@ -102,46 +86,28 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, Ka, Kb = a.shape[a_dim], b.shape[b_dim] if Ka != Kb: raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") - if Ka % 16 != 0: - raise ValueError('Reduction size for SDD must be a multiple of 16') # allocate output - n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)]) - c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device) - # each iteration of the loop below - # computes the value for one group of super-blocks - # (e.g., all 4x4 super-blocks) - for lut, width, pack in zip(luts, widths, packs): - # maximum grid size in Triton/CUDA is 64k but we may have more - # super-blocks than that. 
- max_grid = 65535 - for off_grid in range(0, width, max_grid): - grid = [1, min(max_grid, width - off_grid), c.shape[0]] - # fmt: off - pgm = _sdd_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(2), c.stride(3), - Ka, off_grid, lut, - TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3, - num_warps=4, - ) - # print(pgm.asm['ptx']) - # exit() + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [1, c.shape[1], c.shape[0]] + _sdd_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), c.stride(2), c.stride(3), + Ka, 0, lut, + TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4, + num_warps=4, + ) return c + def sdd_lut(layout, block, device): - start_width = 128 // block - layout = layout.type(torch.int32) - superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width) - luts, widths, packs = [], [], [] - for size, nnz in superblocks: - nnz = nnz.reshape(-1, 4) - width = nnz.shape[0] // (size * size) - luts.append(torch.from_numpy(nnz).type(torch.int32).to(device)) - widths.append(width) - packs.append(size) - return luts, None, widths, packs + lut = layout.nonzero(as_tuple=False).to(device).int() + return lut, None # ----------------------------- # Dense = Sparse x Dense (DSD) @@ -154,12 +120,10 @@ def _dsd_kernel( stride_az, stride_ha, stride_am, stride_ak, stride_zb, stride_hb, stride_bk, stride_bn, stride_zc, stride_hc, stride_cm, stride_cn, - DS0, DS1, lut, **meta + DS0, DS1, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr ): - TILE_M = meta['TILE_M'] - TILE_N = meta['TILE_N'] - TILE_K = meta['TILE_K'] - GROUP_SIZE_M = meta['GROUP_SIZE_M'] #------------# #- Prologue -# #------------# @@ -167,9 +131,9 @@ def _dsd_kernel( pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) pidz = tl.program_id(2) - header = lut + pid_m * 4 + header = lut + pid_n * 4 offset = tl.load(header + 0) K = tl.load(header + 1) column = tl.load(header + 2) @@ -185,7 +149,8 @@ def _dsd_kernel( + offs_am[:, None] * stride_am \ + offs_ak[None, :] * stride_ak # initialize pointers to B (dense) - offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N) + offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) start_bk = tl.load(pinc) start_bk = tl.multiple_of(start_bk, 8) # compiler hint offs_bk = start_bk + tl.arange(0, TILE_K) @@ -197,28 +162,33 @@ def _dsd_kernel( ## Inner Loop ## ## ---------------- ## acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) for k in range(K, 0, -TILE_K): a = tl.load(pa, mask=True) b = tl.load(pb, mask=offs_bn[None, :] < 
DS0) acc += tl.dot(a, b) + pa += inc_a + pb += inc_b*stride_bk pinc += 2 inc_a = tl.load(pinc + 1) inc_a = tl.multiple_of(inc_a, 8) inc_b = tl.load(pinc) inc_b = tl.multiple_of(inc_b, 8) - pa += inc_a - pb += inc_b*stride_bk c = acc.to(C.dtype.element_ty) # initialize pointers to C offs_cm = column*TILE_M + tl.arange(0, TILE_M) - offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N) + offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N) pc = C + off_h * stride_hc \ + pidz * stride_zc \ + offs_cm[:, None] * stride_cm \ + offs_cn[None, :] * stride_cn tl.store(pc, c, mask = offs_cn[None, :] < DS0) -def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): # shapes / dtypes AS1 = block * spdims[2 if trans_a else 1] BS0 = b.size(0) @@ -230,11 +200,15 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w CS1 = BS1 CS2 = BS3 if trans_c else AS1 CS3 = AS1 if trans_c else BS3 - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out # meta-parameter heuristics - TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block] + TILE_N = 128 # compute output - grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0] + grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0] # fmt: off _dsd_kernel[grid]( a, b, c, @@ -242,8 +216,8 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), BS3, AS1, lut, - TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3, - num_warps=4, GROUP_SIZE_M=8, + TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + num_warps=4, GROUP_SIZE_M=4, ) # exit() return c @@ -323,7 +297,7 @@ def dsd_lut(layout, block, step, trans, device): lut = torch.cat((header, incs)) lut = lut.type(torch.int32).to(device) # create locks - return lut, None, width, None + return lut, width # ----------------------------- # Dense = Dense x Sparse (DDS) @@ -334,12 +308,10 @@ def _dds_kernel( stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_bk, stride_bn, stride_zc, stride_hc, stride_mc, stride_nc, - DS0, DS1, lut, **meta + DS0, DS1, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, ): - TILE_M = meta['TILE_M'] - TILE_N = meta['TILE_N'] - TILE_K = meta['TILE_K'] - GROUP_SIZE_M = meta['GROUP_SIZE_M'] #------------# #- Prologue -# #------------# @@ -347,16 +319,17 @@ def _dds_kernel( pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) pid_z = tl.program_id(2) - header = lut + pid_m * 4 + header = lut + pid_n * 4 offset = tl.load(header + 0) AS1 = tl.load(header + 1) column = tl.load(header + 2) off_h = tl.load(header + 3) pinc = lut + offset # initialize pointers to A (dense) - offs_am = pid_n*TILE_M + tl.arange(0, TILE_M) + offs_am = pid_m*TILE_M + tl.arange(0, TILE_M) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M) 
start_ak = tl.load(pinc) start_ak = tl.multiple_of(start_ak, 8) offs_ak = start_ak + tl.arange(0, TILE_K) @@ -394,7 +367,7 @@ def _dds_kernel( ## ---------------- ## c = acc.to(C.dtype.element_ty) # initialize pointers to C (dense) - offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M) + offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M) offs_cn = column * TILE_N + tl.arange(0, TILE_N) ptrs_c = C + off_h * stride_hc \ + pid_z * stride_zc \ @@ -403,7 +376,7 @@ def _dds_kernel( # write back tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) -def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): # shapes / dtypes AS0 = a.size(0) AS1 = a.size(1) @@ -415,9 +388,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w CS1 = AS1 CS2 = BS2 if trans_c else AS2 CS3 = AS2 if trans_c else BS2 - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] - grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0] + grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0] # fmt: off _dds_kernel[grid]( a, b, c, @@ -425,8 +402,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), AS2, BS2, lut, - TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3, - num_warps=4, GROUP_SIZE_M=8, + TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + num_warps=4, GROUP_SIZE_M=4, ) return c @@ -439,25 +416,23 @@ class _matmul(torch.autograd.Function): @staticmethod def forward( - ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks, - da_width, da_packs, db_lut, db_num_locks, db_width, db_packs + ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, + c_lut, c_width, da_lut, da_width, db_lut, db_width, out ): - c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs) + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) # save for backward ctx.save_for_backward(a, b) - ctx.da_num_locks = da_num_locks ctx.da_lut = da_lut ctx.da_width = da_width - ctx.da_packs = da_packs ctx.db_lut = db_lut - ctx.db_num_locks = db_num_locks ctx.db_width = db_width - ctx.db_packs = db_packs ctx.mode = mode ctx.spdims = spdims ctx.block = block ctx.trans_a = trans_a ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None return c @staticmethod @@ -466,155 +441,55 @@ class _matmul(torch.autograd.Function): a, b = ctx.saved_tensors da, db = None, None mode = ctx.mode - # gradients w.r.t. a if ctx.needs_input_grad[0]: mode_da = mode[1] + mode[0] + mode[2] da = _matmul.fn[mode_da]( - dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width, - ctx.da_packs + dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width, ) # gradients w.r.t. 
b if ctx.needs_input_grad[1]: mode_db = mode[2] + mode[1] + mode[0] db = _matmul.fn[mode_db]( - a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width, - ctx.db_packs + a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width, ) + dout = dc if ctx.has_out else None return da, db, None, None, None,\ None, None, None, None,\ - None, None, None, None, None, None,\ - None, None, None, None, None, None,\ - None, None, None, None, None, None - + None, None, None, None, None, dout class matmul: - def make_lut(self, dtype, device): - key = (dtype, device) - if key in self.lut_cache: - return self.lut_cache[key] - # C look-up table - layout, block = self.layout, self.block - step = min(block, 32) - if self.mode == 'sdd': - c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device) - elif self.mode == 'dsd': - c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device) - elif self.mode == 'dds': - c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device) - # DA look-up table - if self.mode == 'sdd': - da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device) - elif self.mode == 'dsd': - da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device) - elif self.mode == 'dds': - da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device) - # DB look-up table - if self.mode == 'sdd': - db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device) - elif self.mode == 'dsd': - db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device) - elif self.mode == 'dds': - db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device) - self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs, - da_lut, da_num_locks, da_width, da_packs, - db_lut, db_num_locks, db_width, db_packs) - return self.lut_cache[key] - def __init__(self, layout, block, mode, trans_a=False, trans_b=False): + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): if mode not in ['sdd', 'dsd', 'dds']: raise NotImplementedError('Supported modes are: sdd, dsd, dds') - # look-up table cache - self.lut_cache = dict() - # attributes self.block = block self.mode = mode self.trans_a = trans_a self.trans_b = trans_b - - layout_dim = layout.ndim - assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s" - - if not mode == 'sdd': - # Dims to be reduced on the 'inside' of the matmul, either -1 or -2 - trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2) - self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner - sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1) - - # Inner dim of the dense input should be equal to the inner dim of the sparse input - self.dense_inner_size = layout.shape[sparse_inner] * block - # Expected shape for sparse inputs - self.sparse_shape = (layout.sum().item(), block, block) - - # Support using the same layout across attention heads etc. 
- if layout_dim == 2: - layout = layout.unsqueeze(0) - - layout = layout.long() # Above code assumes the layout tensor is an integral type + self.trans_c = trans_c self.layout = layout self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) - def __call__(self, a, b): - c_lut, c_num_locks, c_width, c_packs,\ - da_lut, da_num_locks, da_width, da_packs,\ - db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device) - - # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior - # and potential illegal memory accesses - original_dims = max(a.ndim, b.ndim) - a, b = self._validate_inputs(a, b) - - # execute + def __call__(self, a, b, out = None): c = _matmul.apply( - a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, - c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs + a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, + self.c_lut, self.c_width, + self.da_lut, self.da_width, + self.db_lut, self.db_width, + out ) - # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input - dims_to_trim = c.ndim - original_dims - for _ in range(dims_to_trim): - c = c.squeeze(0) - - return c - - def _validate_inputs(self, a, b): - if a.device != b.device: - raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A " - f"and {b.device} for tensor B") - if not a.is_cuda: - raise ValueError("Only GPU devices are supported for now") - - # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here - if torch.is_autocast_enabled(): - a, b = a.half(), b.half() - elif a.dtype != b.dtype: - raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B") - - mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b - if mode != 'sdd': - # One input is sparse - dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A') - dense_inner = dense.shape[self.dense_inner_dim] - if dense_inner != self.dense_inner_size: - raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim " - f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.") - - if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape: - raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument " - f"{sparse_name}, got {sparse.shape}") - - def add_extra_dims(x): - # Add extra leading singleton dimensions if needed - dims_needed = 4 - x.ndim - if dims_needed > 0: - singletons = [1] * dims_needed - x = x.view(*singletons, *x.shape) - elif dims_needed < 0: - raise ValueError("Tensors with more than 4 dimensions are 
not currently supported") - - return x - - # Pad shapes with leading singleton dimensions - a = add_extra_dims(a) - b = add_extra_dims(b) - - return a, b + return c \ No newline at end of file diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 10f806af2..5b9d752ec 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -16,10 +16,9 @@ def num_warps(n): @triton.jit def _forward( X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - **meta + TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr, + KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr, ): - TN = meta['TN'] - BLOCK = meta['BLOCK'] pidhm = tl.program_id(0) pidz = tl.program_id(1) # create index ranges @@ -43,25 +42,25 @@ def _forward( x = tl.load(px, mask=check, other=-float('inf')) x = x.to(tl.float32) # apply scale - if meta['APPLY_SCALE']: + if APPLY_SCALE: x = x * scale # apply RPE - if meta['APPLY_RPE']: + if APPLY_RPE: prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn rpe = tl.load(prpe, mask=check, other=0) x = x + rpe # apply key-padding mask - if meta['APPLY_KP_MASK']: + if APPLY_KP_MASK: pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn kp_m = tl.load(pkp_m, mask=check, other=-float('inf')) - if meta['KP_MASK_MUL']: + if KP_MASK_MUL: kp_m = tl.where(kp_m == 0, -float('inf'), 0.) x = x + kp_m # apply attention mask - if meta['APPLY_ATTN_MASK']: + if APPLY_ATTN_MASK: pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn attn_m = tl.load(pattn_m, mask=check, other=-float('inf')) - if meta['ATTN_MASK_MUL']: + if ATTN_MASK_MUL: attn_m = tl.where(attn_m == 0, -float('inf'), 0.) 
x = x + attn_m # apply causal mask @@ -75,11 +74,9 @@ def _forward( @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) @triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']}) @triton.jit -def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): +def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr): pidhm = tl.program_id(0) pidz = tl.program_id(1) - TN = meta['TN'] - BLOCK = meta['BLOCK'] # create index ranges rxm = pidhm % BLOCK rbm = pidhm // BLOCK @@ -172,8 +169,7 @@ class _softmax(torch.autograd.Function): APPLY_KP_MASK = apply_kp_mask, APPLY_ATTN_MASK = apply_attn_mask, KP_MASK_MUL = (kp_mask_mode == 'mul'), - ATTN_MASK_MUL = (attn_mask_mode == 'mul'), - force_nc_cache = True) + ATTN_MASK_MUL = (attn_mask_mode == 'mul')) # save to context ctx.mark_dirty(x) ctx.save_for_backward(x, lut) @@ -196,7 +192,7 @@ class _softmax(torch.autograd.Function): # run kernel M = x.shape[0] grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] - _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block) + _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block) return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 87833e3c0..8711d5b19 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -26,8 +26,7 @@ def num_warps(N): @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])}) @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])}) @triton.jit -def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): - BLOCK = meta['BLOCK'] +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): row = tl.program_id(0) cols = tl.arange(0, BLOCK) idx = tl.load(IDX + row) @@ -52,8 +51,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])}) @triton.jit -def _backward(PROBS, IDX, DPROBS, N, **meta): - BLOCK = meta['BLOCK'] +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): row = tl.program_id(0) cols = tl.arange(0, BLOCK) idx = tl.load(IDX + row) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 22f5f6cc2..802908657 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -26,13 +26,9 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - LOCKS, **META): - # extract meta-parameters - BLOCK_M = META['BLOCK_M'] - BLOCK_N = META['BLOCK_N'] - BLOCK_K = META['BLOCK_K'] - GROUP_M = META['GROUP_M'] - SPLIT_K = META['SPLIT_K'] + LOCKS, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -55,7 +51,7 @@ def _kernel(A, B, C, M, N, K, B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(K, 0, -BLOCK_K*SPLIT_K): - if META['EVEN_K']: + if EVEN_K: a = tl.load(A) b = tl.load(B) else: @@ -113,14 +109,11 @@ class _matmul(torch.autograd.Function): locks = _matmul._locks[device] # launch kernel 
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, - M, N, K, + _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - locks, - GROUP_M=8) - # done + locks, GROUP_M=8) return c @staticmethod diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 4446cf6e9..0934c8ea1 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn # Compute Kernel # -------------------------- +from triton.language.core import constexpr import torch import triton import triton.language as tl @@ -23,9 +24,9 @@ def add_kernel( y_ptr, # *Pointer* to second input vector output_ptr, # *Pointer* to output vector n_elements, # Size of the vector - **meta, # Optional meta-parameters for the kernel + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process + # NOTE: `constexpr` so it can be used as a shape value ): - BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process # There are multiple 'program's processing different data. We identify which program # we are here pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 @@ -37,8 +38,8 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) # Create a mask to guard memory operations against out-of-bounds accesses mask = offsets < n_elements - # Load x and y from DRAM, masking out any extar elements in case the input is not a - # multiple of the block size + # Load x and y from DRAM, masking out any extra elements in case + # the input is not a multiple of the block size x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) output = x + y diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 15efa7c81..2c0cfb9a8 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -65,11 +65,11 @@ import triton.language as tl @triton.jit def softmax_kernel( - output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta + output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, + BLOCK_SIZE: tl.constexpr ): # The rows of the softmax are independent, so we parallelize across those row_idx = tl.program_id(0) - BLOCK_SIZE = meta['BLOCK_SIZE'] # The stride represents how much we need to increase the pointer to advance 1 row row_start_ptr = input_ptr + row_idx * input_row_stride # The block size is the next power of two greater than n_cols, so we can fit each diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 1d9fea638..2d2ab91e9 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -182,17 +182,13 @@ def matmul_kernel( stride_bk, stride_bn, stride_cm, stride_cn, # Meta-parameters - **meta, -): + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, + ): """Kernel for computing the matmul C = A x B. 
A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ - # extract meta-parameters - BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] - BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] - BLOCK_SIZE_K = meta['BLOCK_SIZE_K'] - GROUP_SIZE_M = 8 - # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse From 5d543521644840408bd6277a7a4665fed83de588 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 4 Nov 2021 13:25:24 -0700 Subject: [PATCH 004/215] [FRONTEND] Significantly reduce kernel launch time (#367) --- python/setup.py | 4 +- python/src/triton.cc | 152 ++++++++++++++++++++++++ python/triton/code_gen.py | 203 +++++++++++++-------------------- python/triton/language/core.py | 10 ++ 4 files changed, 245 insertions(+), 124 deletions(-) diff --git a/python/setup.py b/python/setup.py index 21cabe182..17db76093 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,7 +18,7 @@ import tarfile def get_llvm(): # tries to find system LLVM - versions = ['-11.0', '-11', '-11-64'] + versions = ['-11.0', '-11', '-11-64'] supported = ['llvm-config{v}'.format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [p for p in paths if p is not None] @@ -127,7 +127,7 @@ setup( description="A language and compiler for custom Deep Learning operations", long_description="", packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"], - install_requires=["torch", "filelock"], + install_requires=["cmake", "torch", "filelock"], package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], diff --git a/python/src/triton.cc b/python/src/triton.cc index 9298f9db4..abe441b3a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -13,6 +13,7 @@ #include #include #include +#include "Python.h" #include #include #include "llvm/IR/Module.h" @@ -23,6 +24,7 @@ namespace py = pybind11; namespace ir = triton::ir; namespace drv = triton::driver; + /*****************************************************************************/ /* Python bindings for triton::driver */ /*****************************************************************************/ @@ -99,8 +101,113 @@ void hip_enqueue(uint64_t stream, uint64_t kernel, } +std::string pow2_divisor(long N){ + if(N % 16 == 0) return "16"; + if(N % 8 == 0) return "8"; + if(N % 4 == 0) return "4"; + if(N % 2 == 0) return "2"; + return "1"; +} + +// Launch +void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names, + std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants, + int num_warps, int num_stages) { + size_t len = PyList_Size(args.ptr()); + params.reserve(8*len); // 8 max bytes by argument + char* params_ptr = ¶ms[0]; + cache_key = func_key; + for(int i = 0; i < len; i++){ + auto arg_ptr = PyList_GetItem(args.ptr(), i); + auto arg = py::handle(arg_ptr); + // argument is `long` + if(PyLong_Check(arg_ptr)){ + int overflow; + long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); + // long and int have different kernels + if(!overflow & (std::abs(value) <= 0xffffffff)){ + cache_key += 'I'; + params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); + std::memcpy(params_ptr, &value, 4); + params_ptr += 4; + } + else{ + cache_key += 'L'; + params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); + 
if(overflow){ + unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr); + std::memcpy(&value, &uvalue, 8); + } + std::memcpy(params_ptr, &value, 8); + params_ptr += 8; + } + // values equal to 1 are specialized + if(value == 1) + cache_key += '1'; + else + cache_key += 'x'; + // values divisible by small powers of 2 are specialized + cache_key += pow2_divisor(value); + continue; + } + // argument is `float` + if(PyFloat_Check(arg_ptr)){ + cache_key += "f"; + float value = PyFloat_AsDouble(arg_ptr); + params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); + std::memcpy(params_ptr, &value, 4); + params_ptr += 4; + continue; + } + // argument is `bool` + if(PyBool_Check(arg_ptr)){ + cache_key += "B"; + bool value = arg_ptr == Py_True ? true : false; + std::memcpy(params_ptr, &value, 1); + params_ptr += 1; + continue; + } + // argument is tensor + PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr); + if(data_ptr){ + cache_key += "P"; + long value = PyLong_AsLong(data_ptr); + params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); + std::memcpy(params_ptr, &value, 8); + params_ptr += 8; + PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype"); + PyObject* repr = PyObject_Repr(dtype); + const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.' + size_t len = PyUnicode_GET_LENGTH(repr) - 6; + cache_key += std::string(start, len); + continue; + } + // argument is `constexpr` + PyObject* value = PyObject_GetAttrString(arg_ptr, "value"); + if(value){ + PyObject* name = PyList_GetItem(arg_names.ptr(), i); + PyDict_SetItem(constants, name, value); + PyObject* repr = PyObject_Repr(value); + const char* start = (const char*)PyUnicode_1BYTE_DATA(repr); + size_t len = PyUnicode_GET_LENGTH(repr); + cache_key += std::string(start, len); + continue; + } + assert(false); + } + cache_key += std::to_string(num_warps); + cache_key += std::to_string(num_stages); + params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]); +} + +// + void init_triton_runtime(py::module &&m) { + // m.def("current_stream", [](uint64_t device){ + // return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream()); + // }); + // wrap backend_t py::enum_(m, "backend") .value("HOST", HOST) @@ -116,6 +223,51 @@ void init_triton_runtime(py::module &&m) { } ); + // cache key + m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names, + py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages, + py::handle add_to_cache, py::handle grid){ + // parse arguments to compute cache key, compile-time constants and packed kernel arguments + long _num_warps = PyLong_AsLong(num_warps.ptr()); + long _num_stages = PyLong_AsLong(num_stages.ptr()); + std::string cache_key; + std::string params; + size_t params_size; + PyObject* constants = PyDict_New(); + parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); + // get cached binary + PyObject* key = PyUnicode_FromString(cache_key.c_str()); + PyObject* bin = nullptr; + if(!PyDict_Contains(bin_cache.ptr(), key)){ + add_to_cache(py::handle(key), args, device, num_warps, num_stages); + } + bin = PyDict_GetItem(bin_cache.ptr(), key); + // get grid + PyObject* grid_ptr = grid.ptr(); + if(!PySequence_Check(grid_ptr)){ + PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__"); + grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr); + } + int size = PySequence_Size(grid_ptr); + int grid_0 = 
PyLong_AsLong(PySequence_GetItem(grid_ptr, 0)); + int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1)); + int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2)); + // enqueue + uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel")); + uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem")); + // actually launch + void *config[] = { + CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(), + CU_LAUNCH_PARAM_BUFFER_SIZE, ¶ms_size, + CU_LAUNCH_PARAM_END + }; + uint64_t _stream = PyLong_AsLong(stream.ptr()); + drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, + _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, + nullptr, config); + return py::handle(bin); + }); + // query maximum shared memory m.def("max_shared_memory", [](backend_t backend, uint64_t device) { if (backend == HOST) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index f14f3b135..b8b9f8129 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -464,6 +464,7 @@ class LoadedBinary: self.module = module self.kernel = kernel self.device = device + self.shared_mem = bin.shared_mem def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, @@ -548,16 +549,6 @@ class Kernel: name = Kernel._type_name(obj) return type_map[name](context) - @staticmethod - def _types_key(*wargs, tensor_idxs): - # type inference - types_key = [None] * len(wargs) - for i, arg in enumerate(wargs): - prefix = 'P' if i in tensor_idxs else '' - suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg) - types_key[i] = prefix + suffix - return tuple(types_key) - @staticmethod def pow2_divisor(N): if N % 16 == 0: return 16 @@ -599,6 +590,53 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) + def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes + args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + if isinstance(a, int) and i not in self.fn.do_not_specialize} + + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: + bin_cache_path = None + bin_lock_path = None + + binary = None + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] + if binary is None: + binary = self._compile( + *wargs, device=device_idx, attributes=attributes, + num_warps=num_warps, num_stages=num_stages, + constants=constants, + ) + if bin_cache_path: + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with 
open(bin_cache_path + ".tmp", "wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) + if JITFunction.cache_hook is not None: + JITFunction.cache_hook(key=key, binary=binary) + + self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -608,112 +646,21 @@ class Kernel: if len(wargs) != len(self.fn.arg_names): raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") # handle annotations - for name, type in self.fn.__annotations__.items(): - pos = self.fn.arg_names.index(name) - assert type == triton.language.core.constexpr - wargs[pos] = type(wargs[pos]) - # device inference - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - if len(tensor_idxs) == 0: - raise ValueError("No Tensor argument found.") - invalid_args = [] - device_ids = [] - for idx in tensor_idxs: - curr = wargs[idx] - if not curr.is_cuda: - invalid_args.append(idx) - else: - device_ids.append(curr.device.index) - if invalid_args: - raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + - " Only CUDA is supported at the moment") - - device = torch.device('cuda', torch.cuda.current_device()) - device_idx = device.index - # if len(set(device_ids)) != 1 or device_ids[0] != device_idx: - # # try to enable P2P communication - # for arg_idx, dst_idx in zip(tensor_idxs, device_ids): - # if dst_idx != device_idx: - # try: - # _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) - # except RuntimeError as e: - # raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" - # .format(device_idx, dst_idx, str(e))) - - # enqueue kernel on the current device - torch.cuda.set_device(device_idx) - # attributes - args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ - if isinstance(a, int) and i not in self.fn.do_not_specialize} - - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - - # compute hash for caching this kernel - types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) - attr_key = tuple(attributes.items()) - const_key = tuple(constants.items()) - compute_capability = torch.cuda.get_device_capability(device) - key = ( - self.fn.cache_key, version_key(), compute_capability, - types_key, attr_key, num_warps, num_stages, const_key - ) - key = repr(key) - - # get cached binary - drv_cache = self.fn.drv_cache - - if key not in drv_cache: - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with 
FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) - - drv_cache[key] = LoadedBinary(device_idx, binary) - # pack arguments - fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)]) - params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)]) - # enqueue cached function into stream - callable = drv_cache[key] - stream = torch.cuda.current_stream(device_idx).cuda_stream - csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)} - grid = grid(csts) if hasattr(grid, '__call__') else grid - if isinstance(grid, int): - grid = tuple(grid) - callable(stream, params, *grid) - return callable + for pos, _type in self.fn.annotations.items(): + wargs[pos] = _type(wargs[pos]) + # query device index and cuda stream + device = torch.cuda.current_device() + # query stream + # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` + # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 + # building a C wrapper to re-use the unpack function would add a build-time torch dependency + # and require different wheels for different torch versions -- undesirable! 
+ bits = torch._C._cuda_getCurrentStream(device) + mask = 1 << 47 + stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask + # make key for cache + return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream, + self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) class Launcher: @@ -723,6 +670,7 @@ class Launcher: def __call__(self, *wargs, **kwargs): return self.kernel(*wargs, **kwargs, grid=self.grid) + class Autotuner: @@ -773,6 +721,11 @@ class Autotuner: return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) +@functools.lru_cache() +def compute_capability(): + device = torch.device('cuda', 0) + return '-'.join(map(str, torch.cuda.get_device_capability(device))) + @functools.lru_cache() def version_key(): import pkgutil @@ -784,22 +737,27 @@ def version_key(): with open(triton._C.libtriton.__file__, "rb") as f: contents += [hashlib.md5(f.read()).hexdigest()] # language - for lib in pkgutil.iter_modules(triton.language.__path__): + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: contents += [hashlib.md5(f.read()).hexdigest()] # ptxas version try: ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() except Exception: - ptxas_version = None - return (triton.__version__, ptxas_version) + tuple(contents) + ptxas_version = '' + return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) class JITFunction: cache_hook = None def _set_cache_key(self): - self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version) + self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest() + self.cache_key += str(self.version) + self.cache_key += version_key() + self.cache_key += compute_capability() + self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest() def __init__(self, fn, version=None, do_not_specialize=None): # information of wrapped function @@ -811,7 +769,7 @@ class JITFunction: self.do_not_specialize = [] if do_not_specialize is None else\ [self.arg_names.index(arg) for arg in do_not_specialize] # cache for callable driver objects (e.g. 
CUkernel) - self.drv_cache = dict() + self.bin_cache = dict() # cache for binaries (on-disk) self._set_cache_key() # JITFunction can be instantiated as kernel @@ -819,6 +777,7 @@ class JITFunction: self.kernel_decorators = [] self.kernel = None # annotations + self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} self.__annotations__ = fn.__annotations__ # forward docs self.__doc__ = fn.__doc__ @@ -834,7 +793,7 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - def __call__(self, *args, generator: CodeGenerator, **meta): + def __call__(self, *args, generator: CodeGenerator): try: gscope = generator.gscope.copy() lscope = generator.lscope.copy() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 5eed3b67f..7875a30f6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -119,6 +119,9 @@ class block: self.shape = (1, ) if self.handle.type.is_block(): self.shape = self.handle.type.shape + self.numel = 1 + for s in self.shape: + self.numel *= s # Data-type wrapper self.dtype = block._init_dtype(self.handle.type.scalar) @@ -352,6 +355,13 @@ def program_id(axis, _builder=None): :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ + # if axis == -1: + # pid0 = frontend.program_id(0, _builder) + # pid1 = frontend.program_id(1, _builder) + # pid2 = frontend.program_id(2, _builder) + # npg0 = frontend.num_programs(0, _builder) + # npg1 = frontend.num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 return frontend.program_id(axis, _builder) From 9a02dddf29803d5229907d01ebc4e7b4edd179f9 Mon Sep 17 00:00:00 2001 From: daadaada Date: Tue, 9 Nov 2021 00:25:05 +0800 Subject: [PATCH 005/215] Fix sdd_lut (#368) --- python/triton/ops/blocksparse/matmul.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 8a020e5c2..49497777a 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -107,6 +107,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out def sdd_lut(layout, block, device): lut = layout.nonzero(as_tuple=False).to(device).int() + lut = lut.contiguous() return lut, None # ----------------------------- From f7ab96cfd754acd304bf83953ea33da3086e0b85 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 5 Nov 2021 09:26:33 -0700 Subject: [PATCH 006/215] [FRONTEND] Fixed some issues with `constexpr` --- python/triton/code_gen.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b8b9f8129..22b910f5a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -152,7 +152,9 @@ class CodeGenerator(ast.NodeVisitor): if target in self.lscope: raise ValueError(f'{target} is already defined.' 
f' constexpr cannot be reassigned.') - self.lscope[target] = triton.language.constexpr(value) + if not isinstance(value, triton.language.constexpr): + value = triton.language.constexpr(value) + self.lscope[target] = value return self.lscope[target] # default: call visit_Assign return self.visit_Assign(node) From e66bf76354a63d1c0d49a2534b51031947ee5251 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 12 Nov 2021 00:55:00 -0800 Subject: [PATCH 007/215] [RUNTIME] Bunch of bugfixes (#372) --- lib/codegen/selection/generator.cc | 4 +- python/test/unit/runtime/test_cache.py | 66 +++++++++++++++++++++++ python/triton/__init__.py | 1 - python/triton/code_gen.py | 73 ++++++++++++++++++++------ python/triton/language/random.py | 1 + 5 files changed, 128 insertions(+), 17 deletions(-) create mode 100644 python/test/unit/runtime/test_cache.py diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 7316e047a..eeabb6841 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -782,11 +782,13 @@ void generator::visit_cat_inst(ir::cat_inst* x) { ir::value* lhs = x->get_operand(0); ir::value* rhs = x->get_operand(1); int i = 0; - for(size_t j = 0; j < idxs_.at(lhs).size(); j ++) + for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){ vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]]; + } for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; } +// std::cout << "!" << std::endl; } diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..215b90d8b --- /dev/null +++ b/python/test/unit/runtime/test_cache.py @@ -0,0 +1,66 @@ +import torch +import triton +from triton.code_gen import JITFunction +import triton.language as tl +import os +import shutil + +tmpdir = ".tmp" + +@triton.jit +def function_1(i): + i = i + 1 + i = function_2(i) + return i + + +@triton.jit +def function_2(i): + i = i + 1 + return i + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +def apply_src_change(target, old, new): + delattr(kernel.fn, 'hash') + delattr(function_1.fn, 'hash') + delattr(function_2.fn, 'hash') + function_1.src = function_1.src.replace(old, new) + target.src = target.src.replace(old, new) + ret = target.cache_key + target.src = target.src.replace(new, old) + return ret + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1') + assert baseline == updated + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2') + assert baseline != updated + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(function_1, 'i + 1', 'i + 2') + assert baseline != updated + +def test_reuse(): + counter = 0 + def inc_counter(key, binary): + nonlocal counter + counter += 1 + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + JITFunction.cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device='cuda') + for i in range(10): + kernel[(1,)](x, 43, BLOCK=1024) + assert counter == 1 diff --git a/python/triton/__init__.py b/python/triton/__init__.py index a76df9b75..4b8c54703 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,7 +6,6 @@ __version__ = '1.1.1' import torch # submodules from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, 
Config, Autotuner, reinterpret - from . import language from . import code_gen from . import testing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 22b910f5a..d418cb9d1 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -412,6 +412,8 @@ class CodeGenerator(ast.NodeVisitor): def visit_Call(self, node): fn = self.visit(node.func) + if isinstance(fn, triton.language.constexpr): + fn = fn.value kws = dict() for keyword in node.keywords: kws.update(self.visit(keyword)) @@ -652,6 +654,9 @@ class Kernel: wargs[pos] = _type(wargs[pos]) # query device index and cuda stream device = torch.cuda.current_device() + torch.cuda.set_device(device) + cc = torch.cuda.get_device_capability(device) + cc = str(cc[0]) + '-' + str(cc[1]) # query stream # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 @@ -660,8 +665,9 @@ class Kernel: bits = torch._C._cuda_getCurrentStream(device) mask = 1 << 47 stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask + # stream = torch.cuda.current_stream(device).cuda_stream # make key for cache - return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream, + return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) @@ -723,11 +729,6 @@ class Autotuner: return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) -@functools.lru_cache() -def compute_capability(): - device = torch.device('cuda', 0) - return '-'.join(map(str, torch.cuda.get_device_capability(device))) - @functools.lru_cache() def version_key(): import pkgutil @@ -750,16 +751,49 @@ def version_key(): ptxas_version = '' return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) +#########################3 + + +class DependenciesFinder(ast.NodeVisitor): + + def __init__(self, globals, src) -> None: + super().__init__() + self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.globals = globals + + def visit_Name(self, node): + return self.globals.get(node.id, None) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or lhs is triton: + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + func = self.visit(node.func) + if func is None: + return + if isinstance(func, triton.JITFunction): + func = func.fn + module = inspect.getmodule(func) + if module and module.__name__.startswith('triton.'): + return + if not hasattr(func, 'hash'): + src = textwrap.dedent(inspect.getsource(func)) + tree = ast.parse(src) + finder = DependenciesFinder(func.__globals__, src) + finder.visit(tree) + func.hash = finder.ret + self.ret = (self.ret + func.hash).encode("utf-8") + self.ret = hashlib.md5(self.ret).hexdigest() + class JITFunction: cache_hook = None - def _set_cache_key(self): - self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest() - self.cache_key += str(self.version) - self.cache_key += version_key() - self.cache_key += compute_capability() - self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest() def __init__(self, fn, version=None, do_not_specialize=None): # information of wrapped function @@ -772,8 +806,6 @@ class JITFunction: [self.arg_names.index(arg) for arg in do_not_specialize] # cache for callable driver 
objects (e.g. CUkernel) self.bin_cache = dict() - # cache for binaries (on-disk) - self._set_cache_key() # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel_decorators = [] @@ -785,6 +817,15 @@ class JITFunction: self.__doc__ = fn.__doc__ + @property + @functools.lru_cache() + def cache_key(self): + if not hasattr(self.fn, 'hash'): + dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.fn.hash = dependencies_finder.ret + return self.fn.hash + # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Some unit tests do this, for example. @@ -821,7 +862,9 @@ class JITFunction: self.kernel = None super(JITFunction, self).__setattr__(name, value) if name == 'src': - self._set_cache_key() + if hasattr(self.fn, 'hash'): + delattr(self.fn, 'hash') + JITFunction.cache_key.fget.cache_clear() def _init_kernel(self): if self.kernel is None: diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 3a3d7f9e1..9bb29588a 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -110,6 +110,7 @@ def randint4x(seed, offset): :param offsets: The offsets to generate random numbers for. """ z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting + seed = seed + 0 seed = hacky_to_uint64(seed) # uint will solve this seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) seed_lo = (seed & 0xffffffff).to(tl.int32) From 01cc3d4503f40b7996f039edde3e5dcf3c8f2ff7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 12 Nov 2021 15:06:55 -0800 Subject: [PATCH 008/215] [RUNTIME] Restored `do_not_specialize` (#374) --- python/src/triton.cc | 10 ++++++--- python/test/unit/runtime/test_cache.py | 30 +++++++++++++++++++++++--- python/triton/code_gen.py | 2 +- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index abe441b3a..5fb6bb8f2 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -110,7 +110,7 @@ std::string pow2_divisor(long N){ } // Launch -void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names, +void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names, std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants, int num_warps, int num_stages) { size_t len = PyList_Size(args.ptr()); @@ -118,6 +118,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n char* params_ptr = ¶ms[0]; cache_key = func_key; for(int i = 0; i < len; i++){ + PyObject* py_i = PyLong_FromLong(i); + bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i); auto arg_ptr = PyList_GetItem(args.ptr(), i); auto arg = py::handle(arg_ptr); // argument is `long` @@ -141,6 +143,8 @@ void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_n std::memcpy(params_ptr, &value, 8); params_ptr += 8; } + if(!specialize) + continue; // values equal to 1 are specialized if(value == 1) cache_key += '1'; @@ -224,7 +228,7 @@ void init_triton_runtime(py::module &&m) { ); // cache key - m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names, + m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names, py::handle device, py::handle stream, py::handle bin_cache, py::handle 
num_warps, py::handle num_stages, py::handle add_to_cache, py::handle grid){ // parse arguments to compute cache key, compile-time constants and packed kernel arguments @@ -234,7 +238,7 @@ void init_triton_runtime(py::module &&m) { std::string params; size_t params_size; PyObject* constants = PyDict_New(); - parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); + parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); // get cached binary PyObject* key = PyUnicode_FromString(cache_key.c_str()); PyObject* bin = nullptr; diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 215b90d8b..a1c994241 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -4,6 +4,7 @@ from triton.code_gen import JITFunction import triton.language as tl import os import shutil +import pytest tmpdir = ".tmp" @@ -25,6 +26,11 @@ def kernel(X, i, BLOCK: tl.constexpr): i = function_1(i) tl.store(X, i) +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) def apply_src_change(target, old, new): delattr(kernel.fn, 'hash') @@ -51,16 +57,34 @@ def test_nested1_change(): updated = apply_src_change(function_1, 'i + 1', 'i + 2') assert baseline != updated +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + def test_reuse(): counter = 0 def inc_counter(key, binary): nonlocal counter counter += 1 - os.environ["TRITON_CACHE_DIR"] = tmpdir - if os.path.exists(tmpdir): - shutil.rmtree(tmpdir) JITFunction.cache_hook = inc_counter + reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') for i in range(10): kernel[(1,)](x, 43, BLOCK=1024) assert counter == 1 + +@pytest.mark.parametrize('mode', ['enable', 'disable']) +def test_specialize(mode): + counter = 0 + def inc_counter(key, binary): + nonlocal counter + counter += 1 + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + function = {'enable': kernel, 'disable': kernel_nospec}[mode] + target = {'enable': 5, 'disable': 1}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1,)](x, i, BLOCK=512) + assert counter == target diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index d418cb9d1..73d114ed1 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -667,7 +667,7 @@ class Kernel: stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask # stream = torch.cuda.current_stream(device).cuda_stream # make key for cache - return _triton.runtime.launch(wargs, self.fn.cache_key + cc, self.fn.arg_names, device, stream, + return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) From b908095872450b237b623155b1e252da2c9d31c1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 12 Nov 2021 15:10:04 -0800 Subject: [PATCH 009/215] [VERSION] Bumped triton.__version__ to 2.0.0 --- python/triton/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 4b8c54703..c079880e9 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,5 +1,5 @@ # version -__version__ = '1.1.1' +__version__ = '2.0.0' # TODO: torch needs to 
be imported first # or pybind11 shows `munmap_chunk(): invalid pointer` From 791b953b2174e3ed66ae86d134ff4cdb051d7190 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 16 Nov 2021 00:17:27 -0800 Subject: [PATCH 010/215] [CODEGEN] Reverted to old way to query current stream --- python/triton/code_gen.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 73d114ed1..efcc2701f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -657,15 +657,15 @@ class Kernel: torch.cuda.set_device(device) cc = torch.cuda.get_device_capability(device) cc = str(cc[0]) + '-' + str(cc[1]) - # query stream - # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` - # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 - # building a C wrapper to re-use the unpack function would add a build-time torch dependency - # and require different wheels for different torch versions -- undesirable! - bits = torch._C._cuda_getCurrentStream(device) - mask = 1 << 47 - stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask - # stream = torch.cuda.current_stream(device).cuda_stream + # # query stream + # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` + # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 + # # building a C wrapper to re-use the unpack function would add a build-time torch dependency + # # and require different wheels for different torch versions -- undesirable! + # bits = torch._C._cuda_getCurrentStream(device) + # mask = 1 << 47 + # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask + stream = torch.cuda.current_stream(device).cuda_stream # make key for cache return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) From 5b7ba3eb96cf85ff5897eeef9e45b18cb5d8a74b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 16 Nov 2021 01:21:03 -0800 Subject: [PATCH 011/215] [CODEGEN] Reverted to old launch method (memory leak?) 
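Whichever launch path is active (the C++ fast path added in [PATCH 004] or the Python path restored below), the user-facing API stays the same: kernels declare their meta-parameters with `tl.constexpr` annotations and are launched through a grid callable. For orientation, here is a minimal end-to-end sketch in the style of the vector-add test earlier in this series; the kernel body, tensor size and block size are illustrative and not part of this commit:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _add(x_ptr, y_ptr, output_ptr, n_elements,
             BLOCK_SIZE: tl.constexpr):
        # Meta-parameters are declared as tl.constexpr arguments
        # instead of being pulled out of **meta.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, x + y, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.randn_like(x)
    output = torch.empty_like(x)
    # The grid callable receives the launch arguments (including the
    # constexpr meta-parameters) and returns the number of programs.
    grid = lambda args: (triton.cdiv(x.numel(), args['BLOCK_SIZE']),)
    _add[grid](x, y, output, x.numel(), BLOCK_SIZE=1024)

A plain tuple such as grid=(4,) is also accepted; the callable form is only needed when the grid depends on the meta-parameters.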
--- python/triton/code_gen.py | 360 +++++++++++++++++++++++++++++++------- 1 file changed, 294 insertions(+), 66 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index efcc2701f..cc0a103c9 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -493,6 +493,184 @@ class OutOfResources(Exception): super().__init__(self.message) +# class Kernel: +# @staticmethod +# def _type_name(obj): +# type_names = { +# triton.language.float8: 'f8', +# torch.bfloat16: 'bf16', +# torch.float16: 'f16', +# torch.float32: 'f32', +# torch.float64: 'f64', +# torch.bool: 'i1', +# torch.int8: 'i8', +# torch.int16: 'i16', +# torch.int32: 'i32', +# torch.int64: 'i64', +# } +# if hasattr(obj, 'data_ptr'): +# return type_names[obj.dtype] +# if isinstance(obj, triton.language.core.constexpr): +# obj = obj.value +# if isinstance(obj, int): +# if abs(obj) <= 0xffffffff: +# return 'I' +# return 'L' +# if isinstance(obj, float): +# return 'f' +# if isinstance(obj, bool): +# return 'B' +# if isinstance(obj, str): +# return 'str' +# assert False + + + +# @staticmethod +# def _to_triton_ir(context, obj): +# type_map = { +# 'I': _triton.ir.type.get_int32, +# 'L': _triton.ir.type.get_int64, +# 'f': _triton.ir.type.get_fp32, +# 'B': _triton.ir.type.get_int1, +# 'f8': _triton.ir.type.get_fp8, +# 'f16': _triton.ir.type.get_fp16, +# 'bf16': _triton.ir.type.get_bf16, +# 'f32': _triton.ir.type.get_fp32, +# 'f64': _triton.ir.type.get_fp64, +# 'i1': _triton.ir.type.get_int1, +# 'i8': _triton.ir.type.get_int8, +# 'i16': _triton.ir.type.get_int16, +# 'i32': _triton.ir.type.get_int32, +# 'i64': _triton.ir.type.get_int64, +# } +# # convert torch.Tensor to Triton IR pointers +# if hasattr(obj, 'data_ptr'): +# name = Kernel._type_name(obj) +# elt_ty = type_map[name](context) +# return _triton.ir.type.make_ptr(elt_ty, 1) +# # default path returns triton.ir.type directly +# name = Kernel._type_name(obj) +# return type_map[name](context) + +# @staticmethod +# def pow2_divisor(N): +# if N % 16 == 0: return 16 +# if N % 8 == 0: return 8 +# if N % 4 == 0: return 4 +# if N % 2 == 0: return 2 +# return 1 + +# def __init__(self, fn): +# self.fn = fn + +# def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): +# wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] +# # create IR module +# context = _triton.ir.context() +# # get just-in-time proto-type of kernel +# arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] +# ret_type = _triton.ir.type.get_void(context) +# prototype = _triton.ir.type.make_function(ret_type, arg_types) +# # generate Triton-IR +# # export symbols visible from self.fn into code-generator object +# gscope = sys.modules[self.fn.module].__dict__ +# generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) +# try: +# generator.visit(self.fn.parse()) +# except Exception as e: +# node = generator.last_node +# if node is None or isinstance(e, (NotImplementedError, CompilationError)): +# raise e +# raise CompilationError(self.fn.src, node, e) +# # Compile to machine code +# if torch.version.hip is None: +# backend = _triton.runtime.backend.CUDA +# else: +# backend = _triton.runtime.backend.ROCM +# name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) +# max_shared_memory = _triton.runtime.max_shared_memory(backend, device) +# if shared_mem > max_shared_memory: +# raise OutOfResources(shared_mem, 
max_shared_memory, "shared memory") +# return Binary(backend, name, asm, shared_mem, num_warps) + +# def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): +# tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] +# # attributes +# args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] +# attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ +# if isinstance(a, int) and i not in self.fn.do_not_specialize} + +# # transforms ints whose value is one into constants for just-in-time compilation +# constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} +# constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) +# hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + +# # create cache directory +# cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') +# if cache_dir and not os.path.exists(cache_dir): +# os.makedirs(cache_dir, exist_ok=True) + +# if cache_dir: +# bin_cache_path = os.path.join(cache_dir, hashed_key) +# bin_lock_path = bin_cache_path + ".lock" +# else: +# bin_cache_path = None +# bin_lock_path = None + +# binary = None +# if bin_cache_path and os.path.exists(bin_cache_path): +# assert bin_lock_path is not None +# with FileLock(bin_lock_path): +# with open(bin_cache_path, 'rb') as f: +# binary = pickle.load(f)["binary"] +# if binary is None: +# binary = self._compile( +# *wargs, device=device_idx, attributes=attributes, +# num_warps=num_warps, num_stages=num_stages, +# constants=constants, +# ) +# if bin_cache_path: +# assert bin_lock_path is not None +# with FileLock(bin_lock_path): +# with open(bin_cache_path + ".tmp", "wb") as f: +# pickle.dump({"binary": binary, "key": key}, f) +# os.rename(bin_cache_path + ".tmp", bin_cache_path) +# if JITFunction.cache_hook is not None: +# JITFunction.cache_hook(key=key, binary=binary) + +# self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + +# def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): +# # handle arguments passed by name +# kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} +# wargs = list(wargs) +# for i, pos in enumerate(sorted(kwargs)): +# wargs.insert(pos + i, kwargs[pos]) +# if len(wargs) != len(self.fn.arg_names): +# raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") +# # handle annotations +# for pos, _type in self.fn.annotations.items(): +# wargs[pos] = _type(wargs[pos]) +# # query device index and cuda stream +# device = torch.cuda.current_device() +# torch.cuda.set_device(device) +# cc = torch.cuda.get_device_capability(device) +# cc = str(cc[0]) + '-' + str(cc[1]) +# # # query stream +# # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` +# # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 +# # # building a C wrapper to re-use the unpack function would add a build-time torch dependency +# # # and require different wheels for different torch versions -- undesirable! 
+# # bits = torch._C._cuda_getCurrentStream(device) +# # mask = 1 << 47 +# # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask +# stream = torch.cuda.current_stream(device).cuda_stream +# # make key for cache +# return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, +# self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) + + class Kernel: @staticmethod def _type_name(obj): @@ -553,6 +731,16 @@ class Kernel: name = Kernel._type_name(obj) return type_map[name](context) + @staticmethod + def _types_key(*wargs, tensor_idxs): + # type inference + types_key = [None] * len(wargs) + for i, arg in enumerate(wargs): + prefix = 'P' if i in tensor_idxs else '' + suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg) + types_key[i] = prefix + suffix + return tuple(types_key) + @staticmethod def pow2_divisor(N): if N % 16 == 0: return 16 @@ -594,53 +782,6 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - # attributes - args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ - if isinstance(a, int) and i not in self.fn.do_not_specialize} - - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) - - self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -650,25 +791,112 @@ class Kernel: if len(wargs) != len(self.fn.arg_names): raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") # handle annotations - for pos, _type in self.fn.annotations.items(): - wargs[pos] = _type(wargs[pos]) - # query device index and cuda stream - device = torch.cuda.current_device() - 
torch.cuda.set_device(device) - cc = torch.cuda.get_device_capability(device) - cc = str(cc[0]) + '-' + str(cc[1]) - # # query stream - # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` - # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 - # # building a C wrapper to re-use the unpack function would add a build-time torch dependency - # # and require different wheels for different torch versions -- undesirable! - # bits = torch._C._cuda_getCurrentStream(device) - # mask = 1 << 47 - # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask - stream = torch.cuda.current_stream(device).cuda_stream - # make key for cache - return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, - self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) + for name, type in self.fn.__annotations__.items(): + pos = self.fn.arg_names.index(name) + assert type == triton.language.core.constexpr + wargs[pos] = type(wargs[pos]) + # device inference + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + if len(tensor_idxs) == 0: + raise ValueError("No Tensor argument found.") + invalid_args = [] + device_ids = [] + for idx in tensor_idxs: + curr = wargs[idx] + if not curr.is_cuda: + invalid_args.append(idx) + else: + device_ids.append(curr.device.index) + if invalid_args: + raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + + " Only CUDA is supported at the moment") + + device = torch.device('cuda', torch.cuda.current_device()) + device_idx = device.index + # if len(set(device_ids)) != 1 or device_ids[0] != device_idx: + # # try to enable P2P communication + # for arg_idx, dst_idx in zip(tensor_idxs, device_ids): + # if dst_idx != device_idx: + # try: + # _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) + # except RuntimeError as e: + # raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" + # .format(device_idx, dst_idx, str(e))) + + # enqueue kernel on the current device + torch.cuda.set_device(device_idx) + # attributes + args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + if isinstance(a, int) and i not in self.fn.do_not_specialize} + + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + + # compute hash for caching this kernel + types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) + attr_key = tuple(attributes.items()) + const_key = tuple(constants.items()) + compute_capability = torch.cuda.get_device_capability(device) + key = ( + self.fn.cache_key, version_key(), compute_capability, + types_key, attr_key, num_warps, num_stages, const_key + ) + key = repr(key) + + # get cached binary + bin_cache = self.fn.bin_cache + + if key not in bin_cache: + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: 
+ bin_cache_path = None + bin_lock_path = None + + binary = None + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] + if binary is None: + binary = self._compile( + *wargs, device=device_idx, attributes=attributes, + num_warps=num_warps, num_stages=num_stages, + constants=constants, + ) + if bin_cache_path: + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path + ".tmp", "wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) + if JITFunction.cache_hook is not None: + JITFunction.cache_hook(key=key, binary=binary) + + bin_cache[key] = LoadedBinary(device_idx, binary) + # pack arguments + fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)]) + params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)]) + # enqueue cached function into stream + callable = bin_cache[key] + stream = torch.cuda.current_stream(device_idx).cuda_stream + csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)} + grid = grid(csts) if hasattr(grid, '__call__') else grid + if isinstance(grid, int): + grid = tuple(grid) + callable(stream, params, *grid) + return callable class Launcher: From edd4b0c8b7e99271023384e52c8cdbc913328235 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 16 Nov 2021 09:53:34 -0800 Subject: [PATCH 012/215] [CODEGEN] Fixed issue with jit function passed as constexpr --- python/triton/code_gen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index cc0a103c9..1e514d3e2 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -700,6 +700,8 @@ class Kernel: return 'B' if isinstance(obj, str): return 'str' + if isinstance(obj, JITFunction): + return '' assert False From 5693b582eac1002c19039cefffa2f70ec747bb77 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 21 Nov 2021 02:30:22 -0800 Subject: [PATCH 013/215] [RUNTIME] Now using pybind11 to avoid memory leaks (#377) --- python/src/triton.cc | 85 ++++----- python/triton/code_gen.py | 362 +++++++------------------------------- 2 files changed, 110 insertions(+), 337 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 5fb6bb8f2..2c165dd06 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -110,18 +110,19 @@ std::string pow2_divisor(long N){ } // Launch -void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names, - std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants, +void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, + std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, int num_warps, int num_stages) { size_t len = PyList_Size(args.ptr()); params.reserve(8*len); // 8 max bytes by argument char* params_ptr = ¶ms[0]; cache_key = func_key; for(int i = 0; i < len; i++){ - PyObject* py_i = PyLong_FromLong(i); - bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i); - auto arg_ptr = PyList_GetItem(args.ptr(), i); - auto arg = py::handle(arg_ptr); + py::int_ py_i = py::int_(i); + bool 
specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end(); + py::object arg = args[i]; + auto arg_ptr = arg.ptr(); + // argument is `long` if(PyLong_Check(arg_ptr)){ int overflow; @@ -172,28 +173,28 @@ void parse_args(py::handle& args, py::handle do_not_specialize, const std::strin continue; } // argument is tensor - PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr); - if(data_ptr){ + if(py::hasattr(arg, "data_ptr")){ + py::object data_ptr = arg.attr("data_ptr")(); cache_key += "P"; - long value = PyLong_AsLong(data_ptr); + long value = data_ptr.cast(); params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); std::memcpy(params_ptr, &value, 8); params_ptr += 8; - PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype"); - PyObject* repr = PyObject_Repr(dtype); - const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.' - size_t len = PyUnicode_GET_LENGTH(repr) - 6; + py::object dtype = arg.attr("dtype"); + py::object repr = py::repr(dtype); + const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' + size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; cache_key += std::string(start, len); continue; } // argument is `constexpr` - PyObject* value = PyObject_GetAttrString(arg_ptr, "value"); + py::object value = arg.attr("value"); if(value){ - PyObject* name = PyList_GetItem(arg_names.ptr(), i); - PyDict_SetItem(constants, name, value); - PyObject* repr = PyObject_Repr(value); - const char* start = (const char*)PyUnicode_1BYTE_DATA(repr); - size_t len = PyUnicode_GET_LENGTH(repr); + py::object name = arg_names[i]; + constants[name] = value; + py::object repr = py::repr(value); + const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()); + size_t len = PyUnicode_GET_LENGTH(repr.ptr()); cache_key += std::string(start, len); continue; } @@ -228,37 +229,39 @@ void init_triton_runtime(py::module &&m) { ); // cache key - m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names, - py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages, - py::handle add_to_cache, py::handle grid){ + m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, + py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, + py::function add_to_cache, py::object grid){ // parse arguments to compute cache key, compile-time constants and packed kernel arguments long _num_warps = PyLong_AsLong(num_warps.ptr()); long _num_stages = PyLong_AsLong(num_stages.ptr()); std::string cache_key; std::string params; size_t params_size; - PyObject* constants = PyDict_New(); + py::dict constants; parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); + // get cached binary - PyObject* key = PyUnicode_FromString(cache_key.c_str()); - PyObject* bin = nullptr; - if(!PyDict_Contains(bin_cache.ptr(), key)){ - add_to_cache(py::handle(key), args, device, num_warps, num_stages); - } - bin = PyDict_GetItem(bin_cache.ptr(), key); + py::str key(cache_key); + if(!bin_cache.contains(key)) + add_to_cache(key, args, device, num_warps, num_stages); + py::object bin = bin_cache[key]; + // get grid - PyObject* grid_ptr = grid.ptr(); - if(!PySequence_Check(grid_ptr)){ - PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__"); - grid_ptr = 
PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr); - } - int size = PySequence_Size(grid_ptr); - int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0)); - int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1)); - int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2)); + py::sequence seq; + if(!PySequence_Check(grid.ptr())) + seq = grid(constants); + else + seq = grid; + int size = seq.size(); + int grid_0 = py::cast(seq[0]); + int grid_1 = size < 2 ? 1 : py::cast(seq[1]); + int grid_2 = size < 3 ? 1 : py::cast(seq[2]); + // enqueue - uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel")); - uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem")); + uint64_t kernel = py::cast(bin.attr("kernel")); + uint64_t shared_mem = py::cast(bin.attr("shared_mem")); + // actually launch void *config[] = { CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(), @@ -269,7 +272,7 @@ void init_triton_runtime(py::module &&m) { drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); - return py::handle(bin); + return bin; }); // query maximum shared memory diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 1e514d3e2..efcc2701f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -493,184 +493,6 @@ class OutOfResources(Exception): super().__init__(self.message) -# class Kernel: -# @staticmethod -# def _type_name(obj): -# type_names = { -# triton.language.float8: 'f8', -# torch.bfloat16: 'bf16', -# torch.float16: 'f16', -# torch.float32: 'f32', -# torch.float64: 'f64', -# torch.bool: 'i1', -# torch.int8: 'i8', -# torch.int16: 'i16', -# torch.int32: 'i32', -# torch.int64: 'i64', -# } -# if hasattr(obj, 'data_ptr'): -# return type_names[obj.dtype] -# if isinstance(obj, triton.language.core.constexpr): -# obj = obj.value -# if isinstance(obj, int): -# if abs(obj) <= 0xffffffff: -# return 'I' -# return 'L' -# if isinstance(obj, float): -# return 'f' -# if isinstance(obj, bool): -# return 'B' -# if isinstance(obj, str): -# return 'str' -# assert False - - - -# @staticmethod -# def _to_triton_ir(context, obj): -# type_map = { -# 'I': _triton.ir.type.get_int32, -# 'L': _triton.ir.type.get_int64, -# 'f': _triton.ir.type.get_fp32, -# 'B': _triton.ir.type.get_int1, -# 'f8': _triton.ir.type.get_fp8, -# 'f16': _triton.ir.type.get_fp16, -# 'bf16': _triton.ir.type.get_bf16, -# 'f32': _triton.ir.type.get_fp32, -# 'f64': _triton.ir.type.get_fp64, -# 'i1': _triton.ir.type.get_int1, -# 'i8': _triton.ir.type.get_int8, -# 'i16': _triton.ir.type.get_int16, -# 'i32': _triton.ir.type.get_int32, -# 'i64': _triton.ir.type.get_int64, -# } -# # convert torch.Tensor to Triton IR pointers -# if hasattr(obj, 'data_ptr'): -# name = Kernel._type_name(obj) -# elt_ty = type_map[name](context) -# return _triton.ir.type.make_ptr(elt_ty, 1) -# # default path returns triton.ir.type directly -# name = Kernel._type_name(obj) -# return type_map[name](context) - -# @staticmethod -# def pow2_divisor(N): -# if N % 16 == 0: return 16 -# if N % 8 == 0: return 8 -# if N % 4 == 0: return 4 -# if N % 2 == 0: return 2 -# return 1 - -# def __init__(self, fn): -# self.fn = fn - -# def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): -# wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] -# # create IR module -# context = _triton.ir.context() -# # get just-in-time proto-type of kernel -# 
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] -# ret_type = _triton.ir.type.get_void(context) -# prototype = _triton.ir.type.make_function(ret_type, arg_types) -# # generate Triton-IR -# # export symbols visible from self.fn into code-generator object -# gscope = sys.modules[self.fn.module].__dict__ -# generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) -# try: -# generator.visit(self.fn.parse()) -# except Exception as e: -# node = generator.last_node -# if node is None or isinstance(e, (NotImplementedError, CompilationError)): -# raise e -# raise CompilationError(self.fn.src, node, e) -# # Compile to machine code -# if torch.version.hip is None: -# backend = _triton.runtime.backend.CUDA -# else: -# backend = _triton.runtime.backend.ROCM -# name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) -# max_shared_memory = _triton.runtime.max_shared_memory(backend, device) -# if shared_mem > max_shared_memory: -# raise OutOfResources(shared_mem, max_shared_memory, "shared memory") -# return Binary(backend, name, asm, shared_mem, num_warps) - -# def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): -# tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] -# # attributes -# args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] -# attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ -# if isinstance(a, int) and i not in self.fn.do_not_specialize} - -# # transforms ints whose value is one into constants for just-in-time compilation -# constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} -# constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) -# hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - -# # create cache directory -# cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') -# if cache_dir and not os.path.exists(cache_dir): -# os.makedirs(cache_dir, exist_ok=True) - -# if cache_dir: -# bin_cache_path = os.path.join(cache_dir, hashed_key) -# bin_lock_path = bin_cache_path + ".lock" -# else: -# bin_cache_path = None -# bin_lock_path = None - -# binary = None -# if bin_cache_path and os.path.exists(bin_cache_path): -# assert bin_lock_path is not None -# with FileLock(bin_lock_path): -# with open(bin_cache_path, 'rb') as f: -# binary = pickle.load(f)["binary"] -# if binary is None: -# binary = self._compile( -# *wargs, device=device_idx, attributes=attributes, -# num_warps=num_warps, num_stages=num_stages, -# constants=constants, -# ) -# if bin_cache_path: -# assert bin_lock_path is not None -# with FileLock(bin_lock_path): -# with open(bin_cache_path + ".tmp", "wb") as f: -# pickle.dump({"binary": binary, "key": key}, f) -# os.rename(bin_cache_path + ".tmp", bin_cache_path) -# if JITFunction.cache_hook is not None: -# JITFunction.cache_hook(key=key, binary=binary) - -# self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) - -# def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): -# # handle arguments passed by name -# kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} -# wargs = list(wargs) -# for i, pos in enumerate(sorted(kwargs)): -# wargs.insert(pos + i, kwargs[pos]) -# if len(wargs) != len(self.fn.arg_names): -# raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but 
{len(wargs)} were given") -# # handle annotations -# for pos, _type in self.fn.annotations.items(): -# wargs[pos] = _type(wargs[pos]) -# # query device index and cuda stream -# device = torch.cuda.current_device() -# torch.cuda.set_device(device) -# cc = torch.cuda.get_device_capability(device) -# cc = str(cc[0]) + '-' + str(cc[1]) -# # # query stream -# # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` -# # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 -# # # building a C wrapper to re-use the unpack function would add a build-time torch dependency -# # # and require different wheels for different torch versions -- undesirable! -# # bits = torch._C._cuda_getCurrentStream(device) -# # mask = 1 << 47 -# # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask -# stream = torch.cuda.current_stream(device).cuda_stream -# # make key for cache -# return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, -# self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) - - class Kernel: @staticmethod def _type_name(obj): @@ -700,8 +522,6 @@ class Kernel: return 'B' if isinstance(obj, str): return 'str' - if isinstance(obj, JITFunction): - return '' assert False @@ -733,16 +553,6 @@ class Kernel: name = Kernel._type_name(obj) return type_map[name](context) - @staticmethod - def _types_key(*wargs, tensor_idxs): - # type inference - types_key = [None] * len(wargs) - for i, arg in enumerate(wargs): - prefix = 'P' if i in tensor_idxs else '' - suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg) - types_key[i] = prefix + suffix - return tuple(types_key) - @staticmethod def pow2_divisor(N): if N % 16 == 0: return 16 @@ -784,6 +594,53 @@ class Kernel: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) + def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes + args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + if isinstance(a, int) and i not in self.fn.do_not_specialize} + + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir and not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: + bin_cache_path = None + bin_lock_path = None + + binary = None + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] + if binary is None: + binary = self._compile( + *wargs, device=device_idx, attributes=attributes, + num_warps=num_warps, num_stages=num_stages, + constants=constants, + ) + if bin_cache_path: + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path + ".tmp", 
"wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) + if JITFunction.cache_hook is not None: + JITFunction.cache_hook(key=key, binary=binary) + + self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -793,112 +650,25 @@ class Kernel: if len(wargs) != len(self.fn.arg_names): raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") # handle annotations - for name, type in self.fn.__annotations__.items(): - pos = self.fn.arg_names.index(name) - assert type == triton.language.core.constexpr - wargs[pos] = type(wargs[pos]) - # device inference - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - if len(tensor_idxs) == 0: - raise ValueError("No Tensor argument found.") - invalid_args = [] - device_ids = [] - for idx in tensor_idxs: - curr = wargs[idx] - if not curr.is_cuda: - invalid_args.append(idx) - else: - device_ids.append(curr.device.index) - if invalid_args: - raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) + - " Only CUDA is supported at the moment") - - device = torch.device('cuda', torch.cuda.current_device()) - device_idx = device.index - # if len(set(device_ids)) != 1 or device_ids[0] != device_idx: - # # try to enable P2P communication - # for arg_idx, dst_idx in zip(tensor_idxs, device_ids): - # if dst_idx != device_idx: - # try: - # _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr()) - # except RuntimeError as e: - # raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}" - # .format(device_idx, dst_idx, str(e))) - - # enqueue kernel on the current device - torch.cuda.set_device(device_idx) - # attributes - args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ - if isinstance(a, int) and i not in self.fn.do_not_specialize} - - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - - # compute hash for caching this kernel - types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs) - attr_key = tuple(attributes.items()) - const_key = tuple(constants.items()) - compute_capability = torch.cuda.get_device_capability(device) - key = ( - self.fn.cache_key, version_key(), compute_capability, - types_key, attr_key, num_warps, num_stages, const_key - ) - key = repr(key) - - # get cached binary - bin_cache = self.fn.bin_cache - - if key not in bin_cache: - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with 
open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) - - bin_cache[key] = LoadedBinary(device_idx, binary) - # pack arguments - fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)]) - params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)]) - # enqueue cached function into stream - callable = bin_cache[key] - stream = torch.cuda.current_stream(device_idx).cuda_stream - csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)} - grid = grid(csts) if hasattr(grid, '__call__') else grid - if isinstance(grid, int): - grid = tuple(grid) - callable(stream, params, *grid) - return callable + for pos, _type in self.fn.annotations.items(): + wargs[pos] = _type(wargs[pos]) + # query device index and cuda stream + device = torch.cuda.current_device() + torch.cuda.set_device(device) + cc = torch.cuda.get_device_capability(device) + cc = str(cc[0]) + '-' + str(cc[1]) + # # query stream + # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` + # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 + # # building a C wrapper to re-use the unpack function would add a build-time torch dependency + # # and require different wheels for different torch versions -- undesirable! 
+ # bits = torch._C._cuda_getCurrentStream(device) + # mask = 1 << 47 + # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask + stream = torch.cuda.current_stream(device).cuda_stream + # make key for cache + return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, + self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) class Launcher: From 1296eb877b242da37af0c9c130baf3fc0633a60a Mon Sep 17 00:00:00 2001 From: daadaada Date: Mon, 22 Nov 2021 03:20:59 +0800 Subject: [PATCH 014/215] [RUNTIME] Config hook v2.0 (#373) * Add pre_hook to triton.Config * Use argument names in triton.heuristics * Update base perf * Remove meta from heuristics --- python/test/regression/test_performance.py | 16 +++++------ python/test/unit/operators/test_matmul.py | 3 +- python/triton/code_gen.py | 15 ++++++++-- python/triton/ops/blocksparse/matmul.py | 2 +- python/triton/ops/blocksparse/softmax.py | 8 +++--- python/triton/ops/cross_entropy.py | 8 +++--- python/triton/ops/matmul.py | 32 +++++++++------------- 7 files changed, 44 insertions(+), 40 deletions(-) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 215003447..eff21fdfd 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -29,16 +29,16 @@ matmul_data = { (1024, 1024, 1024 ) : {'v100': 0.466}, (2048, 2048, 2048 ) : {'v100': 0.680}, (4096, 4096, 4096 ) : {'v100': 0.831}, - (8192, 8192, 8192 ) : {'v100': 0.841}, + (8192, 8192, 8192 ) : {'v100': 0.849}, # tall-skinny (16 , 1024, 1024 ) : {'v100': 0.0128}, - (16 , 4096, 4096 ) : {'v100': 0.0558}, + (16 , 4096, 4096 ) : {'v100': 0.0883}, (16 , 8192, 8192 ) : {'v100': 0.101}, - (64 , 1024, 1024 ) : {'v100': 0.049}, - (64 , 4096, 4096 ) : {'v100': 0.211}, + (64 , 1024, 1024 ) : {'v100': 0.073}, + (64 , 4096, 4096 ) : {'v100': 0.228}, (64 , 8192, 8192 ) : {'v100': 0.360}, - (1024, 64 , 1024 ) : {'v100': 0.0469}, - (4096, 64 , 4096 ) : {'v100': 0.198}, + (1024, 64 , 1024 ) : {'v100': 0.0692}, + (4096, 64 , 4096 ) : {'v100': 0.223}, (8192, 64 , 8192 ) : {'v100': 0.323}, # # deep reductions # (64 , 64 , 16384) : {'v100': 0.}, @@ -56,7 +56,7 @@ def test_matmul(M, N, K): a = torch.randn((M, K), dtype=torch.float16, device='cuda') b = torch.randn((K, N), dtype=torch.float16, device='cuda') fn = lambda: triton.ops.matmul(a, b) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) cur_gpu_perf = 2.*M*N*K/ms * 1e-9 cur_gpu_util = cur_gpu_perf / max_gpu_perf triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) @@ -101,7 +101,7 @@ def test_elementwise(N): y = torch.randn_like(z) grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250) cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6 cur_gpu_util = cur_gpu_perf / max_gpu_perf triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 0751d044d..75241c291 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -67,7 +67,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, 
torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} - configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)] + pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] kernel = triton.ops._matmul.kernel decorators = kernel.kernel_decorators kernel.kernel_decorators = [] diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index efcc2701f..84d77795c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -574,7 +574,7 @@ class Kernel: prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR # export symbols visible from self.fn into code-generator object - gscope = sys.modules[self.fn.module].__dict__ + gscope = self.fn.fn.__globals__ generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) try: generator.visit(self.fn.parse()) @@ -698,6 +698,7 @@ class Autotuner: for i in self.reset_idx: args[i].zero_() self.hook = _hook + self.arg_names = arg_names def _bench(self, *args, config, **meta): # check for conflicts, i.e. meta-parameters both provided @@ -711,11 +712,14 @@ class Autotuner: # augment meta-parameters with tunable ones current = dict(meta, **config.kwargs) def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) self.hook(args) self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) return triton.testing.do_bench(kernel_call) def __call__(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) if len(self.configs) > 1: key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: @@ -726,6 +730,8 @@ class Autotuner: config = self.cache[key] else: config = self.configs[0] + if config.pre_hook != None: + config.pre_hook(self.nargs) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) @@ -893,11 +899,14 @@ class Config: :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. Mostly useful for matrix multiplication workloads on SM80+ GPUs. :type num_stages: int + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. 
""" - def __init__(self, kwargs, num_warps=4, num_stages=2): + def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_stages = num_stages + self.pre_hook = pre_hook def autotune(configs, key, reset_to_zero=None): @@ -963,7 +972,7 @@ def heuristics(values): def fun(*args, **meta): for v, heur in values.items(): assert v not in meta - meta[v] = heur(*args, **meta) + meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) return kernel(*args, **meta) return fun diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 49497777a..9c3317fe0 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -12,7 +12,7 @@ import torch # ******************************************************** @triton.heuristics({ - 'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0, + 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, }) @triton.jit def _sdd_kernel( diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 5b9d752ec..dcf77afc8 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -11,8 +11,8 @@ def num_warps(n): return 16 -@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])}) -@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])}) +@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])}) +@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])}) @triton.jit def _forward( X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, @@ -71,8 +71,8 @@ def _forward( tl.store(px, x, mask=check) -@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) -@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']}) +@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])}) +@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']}) @triton.jit def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr): pidhm = tl.program_id(0) diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 8711d5b19..529b6c675 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -23,8 +23,8 @@ def num_warps(N): return 16 -@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])}) -@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])}) +@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) @triton.jit def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): row = tl.program_id(0) @@ -48,8 +48,8 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): tl.store(LOSS + row, probs) -@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) -@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])}) +@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) @triton.jit def _backward(PROBS, IDX, DPROBS, N, BLOCK: 
tl.constexpr): row = tl.program_id(0) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 802908657..ae404b8d6 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -2,9 +2,11 @@ import torch import triton.language as tl import triton +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() @triton.heuristics({ - 'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0, + 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, }) @triton.autotune( configs=[ @@ -18,6 +20,14 @@ import triton triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')), + triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')), ], key=['M', 'N', 'K'], ) @@ -26,7 +36,6 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - LOCKS, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): # matrix multiplication @@ -70,18 +79,7 @@ def _kernel(A, B, C, M, N, K, if SPLIT_K == 1: tl.store(C, acc, mask=mask) else: - LOCKS = LOCKS + tl.program_id(0) - COUNT = LOCKS + tl.num_programs(0) - while tl.atomic_cas(LOCKS, 0, 1) == 1: - pass - count = tl.load(COUNT) - if count == 0: - tl.store(C, acc, mask=mask) - else: - curr = tl.load(C, mask=mask, other=0.) 
- tl.store(C, acc + curr, mask=mask) - tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K) - tl.atomic_xchg(LOCKS, 0) + tl.atomic_add(C, acc, mask=mask) class _matmul(torch.autograd.Function): @@ -103,17 +101,13 @@ class _matmul(torch.autograd.Function): _, N = b.shape # allocates output c = torch.empty((M, N), device=device, dtype=a.dtype) - # allocate locks for split-k - if a.device not in _matmul._locks: - _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device) - locks = _matmul._locks[device] # launch kernel grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - locks, GROUP_M=8) + GROUP_M=8) return c @staticmethod From c86ad9c9abd2af491af53e52d066a6d50e662a07 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 29 Nov 2021 19:11:26 -0800 Subject: [PATCH 015/215] [FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379) --- python/src/triton.cc | 20 ++- python/test/unit/language/test_core.py | 22 +++ python/triton/code_gen.py | 35 ++++- python/triton/language/random.py | 194 ++++++++++--------------- 4 files changed, 149 insertions(+), 122 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 2c165dd06..26c233287 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f continue; } // argument is `constexpr` - py::object value = arg.attr("value"); - if(value){ + if(py::hasattr(arg, "value")){ + py::object value = arg.attr("value"); py::object name = arg_names[i]; constants[name] = value; py::object repr = py::repr(value); @@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f cache_key += std::string(start, len); continue; } - assert(false); + std::string ty_str = arg.attr("__class__").attr("__name__").cast(); + std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "." 
+ + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; + throw std::runtime_error(err_msg); } cache_key += std::to_string(num_warps); cache_key += std::to_string(num_stages); @@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) { CU_LAUNCH_PARAM_END }; uint64_t _stream = PyLong_AsLong(stream.ptr()); - drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, - _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, - nullptr, config); + if(grid_0*grid_1*grid_2 > 0) + drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, + _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, + nullptr, config); return bin; }); @@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) { const std::string &args, int64_t shared_mem){ void* args_ptr = (void*)args.data(); size_t args_size = args.size(); + // release the gil in case the enqueue blocks + // cuda will block if too many ops are enqueued + Py_BEGIN_ALLOW_THREADS if(backend == HOST) host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); if(backend == CUDA) cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); if(backend == ROCM) hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); + Py_END_ALLOW_THREADS }); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9354ec233..98c8c34fa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -634,6 +634,28 @@ def test_load_cache_modifier(cache): # test while # --------------- +# --------------- +# test default +# --------------- +#TODO: can't be local to test_default +@triton.jit +def _impl(value = 10): + return value + +def test_default(): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') + ret1 = torch.zeros(1, dtype=torch.int32, device='cuda') + + @triton.jit + def _kernel(ret0, ret1, value): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1,)](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + # --------------- # test noop #---------------- diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 84d77795c..b2fded136 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor): def visit_FunctionDef(self, node, inline=False, arg_values=None): arg_names, kwarg_names = self.visit(node.args) + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i-1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + self.visit(init_node) # store keyword arguments in local scope self.lscope[kwarg_names] = self.kwargs # initialize function @@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor): iterator = self.visit(node.iter.func) if iterator != self.builtins['range']: raise RuntimeError('Only `range` iterator currently supported') + # static for loops: all iterator arguments are constexpr + iter_args = [self.visit(arg) for arg in node.iter.args] + is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) + if 
is_static: + st_target = ast.Name(id=node.target.id, ctx=ast.Store()) + iter_args = [arg.value for arg in iter_args] + range = iterator(*iter_args) + if len(range) <= 10: + for i in iterator(*iter_args): + self.lscope[node.target.id] = triton.language.constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return # create nodes st_target = ast.Name(id=node.target.id, ctx=ast.Store()) ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) @@ -483,6 +508,7 @@ class CompilationError(Exception): self.message += '\n' + ' ' * node.col_offset + '^' self.message += '\n Error: ' + str(err) super().__init__(self.message) + self.args = (src, node, err) class OutOfResources(Exception): @@ -491,6 +517,7 @@ class OutOfResources(Exception): f'Required: {required}'\ f'Hardware limit: {limit}' super().__init__(self.message) + self.args = (required, limit, name) class Kernel: @@ -805,7 +832,10 @@ class JITFunction: # information of wrapped function self.fn = fn self.module = fn.__module__ - self.arg_names = inspect.getfullargspec(fn).args + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + self.arg_defaults = [v.default for v in signature.parameters.values()] + self.version = version self.src = textwrap.dedent(inspect.getsource(fn)) self.do_not_specialize = [] if do_not_specialize is None else\ @@ -829,7 +859,7 @@ class JITFunction: if not hasattr(self.fn, 'hash'): dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src) dependencies_finder.visit(self.parse()) - self.fn.hash = dependencies_finder.ret + self.fn.hash = dependencies_finder.ret + version_key() return self.fn.hash # we do not parse `src` in the constructor because @@ -848,6 +878,7 @@ class JITFunction: lscope = generator.lscope.copy() values = generator.module.get_values().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ + generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) generator.gscope = gscope generator.lscope = lscope diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 9bb29588a..a831af487 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -7,98 +7,56 @@ from . import core as tl # 2. multiply_low_high is currently inefficient. # 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. 
uint_to_uniform_float +PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 +PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 +PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 +PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57 +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox -@triton.jit -def PHILOX_KEY_A(): - # 0x9E3779B9 - return -1640531527 - - -@triton.jit -def PHILOX_KEY_B(): - # 0xBB67AE85 - return -1150833019 - - -@triton.jit -def PHILOX_ROUND_A(): - # 0xD2511F53 - return -766435501 - - -@triton.jit -def PHILOX_ROUND_B(): - # 0xCD9E8D57 - return -845247145 +# ------------------- +# randint +# ------------------- @triton.jit def hacky_to_uint64(x): return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64) - @triton.jit -def single_round(c0, c1, c2, c3, k0, k1): - A = PHILOX_ROUND_A() - B = PHILOX_ROUND_B() - _c0, _c2 = c0, c2 - c0 = tl.umulhi(B, _c2) ^ c1 ^ k0 - c2 = tl.umulhi(A, _c0) ^ c3 ^ k1 - c1 = B * _c2 - c3 = A * _c0 +def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = tl.umulhi(B, _c2) ^ c1 ^ k0 + c2 = tl.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B return c0, c1, c2, c3 - @triton.jit -def raise_key(k0, k1): - return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B()) - -@triton.jit -def philox_f(c0, c1, c2, c3, k0, k1): - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - return c0, c1, c2, c3 - - - -@triton.jit -def uint32_to_uniform_float(x): +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). - This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly - covers all the possible values it can take. + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. """ - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. 
- x = tl.where(x < 0, -x - 1, x) - return x * max + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret @triton.jit -def pair_uniform_to_normal(u1, u2): - """Box-Muller transform""" - u1 = tl.maximum(1.0e-7, u1) - th = 6.283185307179586 * u2 - r = tl.sqrt(-2.0 * tl.log(u1)) - return r * tl.cos(th), r * tl.sin(th) - - -@triton.jit -def randint4x(seed, offset): +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, returns four blocks of random :code:`int32`. @@ -114,27 +72,26 @@ def randint4x(seed, offset): seed = hacky_to_uint64(seed) # uint will solve this seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) seed_lo = (seed & 0xffffffff).to(tl.int32) - return philox_f(offset, z, z, z, seed_lo, seed_hi) + return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) +# ------------------- +# rand +# ------------------- + @triton.jit -def randint(seed, offset): +def uint32_to_uniform_float(x): """ - Given a :code:`seed` scalar and an :code:`offset` block, returns a single - block of random :code:`int32`. - - If you need multiple streams of random numbers, - using `randint4x` is likely to be faster than calling `randint` 4 times. - - :param seed: The seed for generating random numbers. - :param offsets: The offsets to generate random numbers for. + Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). + This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly + covers all the possible values it can take. """ - ret, _, _, _ = randint4x(seed, offset) - return ret - + max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + x = tl.where(x < 0, -x - 1, x) + return x * max @triton.jit -def rand(seed, offset): +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, returns a block of random :code:`float32` in :math:`U(0, 1)` @@ -142,28 +99,11 @@ def rand(seed, offset): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - source = randint(seed, offset) + source = randint(seed, offset, n_rounds) return uint32_to_uniform_float(source) - @triton.jit -def randn(seed, offset): - """ - Given a :code:`seed` scalar and an :code:`offset` block, - returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` - - :param seed: The seed for generating random numbers. - :param offsets: The offsets to generate random numbers for. - """ - i1, i2, _, _ = randint4x(seed, offset) - u1 = uint32_to_uniform_float(i1) - u2 = uint32_to_uniform_float(i2) - n1, _ = pair_uniform_to_normal(u1, u2) - return n1 - - -@triton.jit -def rand4x(seed, offsets): +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offsets` block, returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)` @@ -171,16 +111,42 @@ def rand4x(seed, offsets): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. 
""" - i1, i2, i3, i4 = randint4x(seed, offsets) + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) u1 = uint32_to_uniform_float(i1) u2 = uint32_to_uniform_float(i2) u3 = uint32_to_uniform_float(i3) u4 = uint32_to_uniform_float(i4) return u1, u2, u3, u4 +# ------------------- +# randn +# ------------------- @triton.jit -def randn4x(seed, offset): +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = tl.sqrt(-2.0 * tl.log(u1)) + return r * tl.cos(th), r * tl.sin(th) + +@triton.jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint32_to_uniform_float(i1) + u2 = uint32_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + +@triton.jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` @@ -188,7 +154,7 @@ def randn4x(seed, offset): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - u1, u2, u3, u4 = rand4x(seed, offset) + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) n1, n2 = pair_uniform_to_normal(u1, u2) n3, n4 = pair_uniform_to_normal(u3, u4) return n1, n2, n3, n4 From 8ec9f037bb9617b9ac4abdc1086a4e3b279afe46 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 30 Nov 2021 22:00:56 -0800 Subject: [PATCH 016/215] [BACKEND/CODE_GEN] Fixed float32 matmul problem (#380) --- lib/codegen/selection/generator.cc | 15 +++++++++------ python/test/unit/language/test_core.py | 15 ++++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index eeabb6841..3c4fae3d8 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -788,7 +788,6 @@ void generator::visit_cat_inst(ir::cat_inst* x) { for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; } -// std::cout << "!" << std::endl; } @@ -1660,13 +1659,17 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va std::map ret = vals_[D]; std::map, Value*> has, hbs; + auto ord = layout_c->get_order(); for(unsigned k = 0; k < NK; k++){ int z = 0; - for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0)) - for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1)) - for(unsigned mm = 0; mm < layout_c->nts(0); mm++) - for(unsigned nn = 0; nn < layout_c->nts(1); nn++) - { + for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1])) + for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0])) + for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++) + for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){ + unsigned m = (ord[0] == 1) ? i : j; + unsigned n = (ord[0] == 1) ? j : i; + unsigned mm = (ord[0] == 1) ? ii : jj; + unsigned nn = (ord[0] == 1) ? 
jj : ii; if(has.find({m + mm, k}) == has.end()){ Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k)); Value* va = load(pa); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 98c8c34fa..6359857fe 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -455,8 +455,8 @@ def test_permute(dtype, shape, perm, device='cuda'): # test dot # --------------- -@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols']) -def test_dot(epilogue, device='cuda'): +@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) +def test_dot(epilogue, dtype=torch.float32, device='cuda'): torch.manual_seed(0) # triton kernel @triton.jit @@ -483,11 +483,13 @@ def test_dot(epilogue, device='cuda'): tl.store(Zs, z) # input M, N, K = 64, 64, 32 - x = triton.testing.random((M, K), dtype=torch.float16, device=device) - y = triton.testing.random((K, N), dtype=torch.float16, device=device) + x = triton.testing.random((M, K), dtype=dtype, device=device) + y = triton.testing.random((K, N), dtype=dtype, device=device) # triton result - z = triton.testing.random((M, N), dtype=torch.float16, device=device) + z = triton.testing.random((M, N), dtype=dtype, device=device) z_tri = z.clone() + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), y, y.stride(0), y.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), @@ -505,10 +507,9 @@ def test_dot(epilogue, device='cuda'): z_ref += z[0,:][None, :] z_ref = z_ref.to(torch.float16) # compare - ptx = pgm.asm['ptx'] - # print(ptx) triton.testing.assert_almost_equal(z_tri, z_ref) # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx From f23bf55f15a2dbbde83200bd002a1bcc144a2c22 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 3 Dec 2021 13:01:01 -0800 Subject: [PATCH 017/215] [RUNTIME] release the gil on launch (#383) --- python/src/triton.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 26c233287..7b6c1ce81 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -272,10 +272,18 @@ void init_triton_runtime(py::module &&m) { CU_LAUNCH_PARAM_END }; uint64_t _stream = PyLong_AsLong(stream.ptr()); - if(grid_0*grid_1*grid_2 > 0) + if(grid_0*grid_1*grid_2 > 0) { + // release the gil in case the enqueue blocks + // cuda will block if too many ops are enqueued + Py_BEGIN_ALLOW_THREADS + + drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); + + Py_END_ALLOW_THREADS + } return bin; }); From e31b9b4e660b037d5680ef7b9924fd012566a22b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 9 Dec 2021 13:21:22 -0800 Subject: [PATCH 018/215] [RUNTIME] Better support for `None` (#387) * regression test fails but it doesn't make sense to me. 
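* Usage sketch (illustrative only -- the kernel and tensors below are made up,
  not taken from this patch or its tests): an argument may now be passed as
  `None`; it is folded into the compile-time constants, contributes "None" to
  the specialization cache key, and `is` / `is not` comparisons on it are
  evaluated as constexpr inside the kernel.

      import torch
      import triton
      import triton.language as tl

      @triton.jit
      def _copy_maybe_bias(X, Bias, Y, N, BLOCK: tl.constexpr):
          offs = tl.arange(0, BLOCK)
          mask = offs < N
          x = tl.load(X + offs, mask=mask)
          # `Bias is not None` becomes a constexpr via visit_Compare,
          # so this branch is meant to be resolved at compile time.
          if Bias is not None:
              x = x + tl.load(Bias + offs, mask=mask)
          tl.store(Y + offs, x, mask=mask)

      x = torch.randn(128, device='cuda')
      y = torch.empty_like(x)
      # two distinct specializations: one keyed on "None", one on a pointer argument
      _copy_maybe_bias[(1,)](x, None, y, x.numel(), BLOCK=128)
      _copy_maybe_bias[(1,)](x, torch.randn_like(x), y, x.numel(), BLOCK=128)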
--- python/src/triton.cc | 13 +++++++++---- python/triton/code_gen.py | 28 +++++++++++++++++++--------- python/triton/language/core.py | 7 +++++-- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 7b6c1ce81..aca171dae 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f if(PyLong_Check(arg_ptr)){ int overflow; long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); + if(specialize && (value == 1)){ + cache_key += '1'; + continue; + } // long and int have different kernels if(!overflow & (std::abs(value) <= 0xffffffff)){ cache_key += 'I'; @@ -147,10 +151,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f if(!specialize) continue; // values equal to 1 are specialized - if(value == 1) - cache_key += '1'; - else - cache_key += 'x'; + cache_key += 'x'; // values divisible by small powers of 2 are specialized cache_key += pow2_divisor(value); continue; @@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f continue; } std::string ty_str = arg.attr("__class__").attr("__name__").cast(); + if(ty_str == "NoneType"){ + cache_key += "None"; + continue; + } std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "." + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; throw std::runtime_error(err_msg); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b2fded136..e4b7d9f0f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -112,6 +112,7 @@ class CodeGenerator(ast.NodeVisitor): else: fn = self.module.get_or_insert_function(node.name, self.prototype) arg_values = [] + idx = 0 for i, arg_name in enumerate(arg_names): if i in self.constants: cst = self.constants[i] @@ -120,13 +121,15 @@ class CodeGenerator(ast.NodeVisitor): arg_values.append(cst) else: if i in self.attributes: - is_ptr = fn.args[i].type.is_ptr() + is_ptr = fn.args[idx].type.is_ptr() attr = 'aligned' if is_ptr else 'multiple_of' attr = getattr(_triton.ir.attribute_kind, attr) attr = _triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(i + 1, attr) - fn.args[i].name = arg_name - arg_values.append(fn.args[i]) + fn.add_attr(idx + 1, attr) + fn.args[idx].name = arg_name + arg_values.append(fn.args[idx]) + idx += 1 + for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) @@ -293,6 +296,10 @@ class CodeGenerator(ast.NodeVisitor): lhs = lhs.value if isinstance(rhs, triton.language.core.constexpr): rhs = rhs.value + if type(node.ops[0]) == ast.Is: + return triton.language.constexpr(lhs is rhs) + if type(node.ops[0]) == ast.IsNot: + return triton.language.constexpr(lhs is not rhs) fn = { ast.Eq: '__eq__', ast.NotEq: '__ne__', @@ -300,8 +307,6 @@ class CodeGenerator(ast.NodeVisitor): ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__', - ast.Is: '__eq__', - ast.IsNot: '__ne__', }[type(node.ops[0])] if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) @@ -313,8 +318,12 @@ class CodeGenerator(ast.NodeVisitor): def visit_UnaryOp(self, node): op = self.visit(node.operand) + if type(node.op) == ast.Not: + assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" + return triton.language.constexpr(not op) if isinstance(op, triton.language.core.constexpr): op = 
op.value + # print(op) fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -592,11 +601,11 @@ class Kernel: self.fn = fn def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): - wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] + fn_args = [arg for i, arg in enumerate(wargs) if i not in constants] + arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args] ret_type = _triton.ir.type.get_void(context) prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR @@ -629,8 +638,9 @@ class Kernel: if isinstance(a, int) and i not in self.fn.do_not_specialize} # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 7875a30f6..d7240fcf8 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -35,6 +35,9 @@ def _patch(fn): builder = args[-1] assert isinstance(builder, ir.builder) args = [_to_ir(x, builder) for x in args] + # for i, arg in enumerate(args): + # if arg is None: + # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} ret = fn(*args, **kwargs) if isinstance(ret, tuple): @@ -77,7 +80,7 @@ class pointer_dtype: def handle(self, builder): return ir.type.make_ptr(self.element_ty.handle(builder), 1) - +# scalar types int1 = dtype(ir.type.get_int1) int8 = dtype(ir.type.get_int8) int16 = dtype(ir.type.get_int16) @@ -88,7 +91,7 @@ float16 = dtype(ir.type.get_fp16) bfloat16 = dtype(ir.type.get_bf16) float32 = dtype(ir.type.get_fp32) float64 = dtype(ir.type.get_fp64) - +# pointer types pi32_t = pointer_dtype(int32) From 9def2424abef1b0c15bdfa5ca78a87a2b246f037 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 9 Dec 2021 15:14:06 -0800 Subject: [PATCH 019/215] [RUNTIME] Fix typo in IfExp --- python/triton/code_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index e4b7d9f0f..deede2530 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -279,7 +279,7 @@ class CodeGenerator(ast.NodeVisitor): def visit_IfExp(self, node): cond = self.visit(node.test) - if cond: + if cond.value: return self.visit(node.body) else: return self.visit(node.orelse) From e575ae3443fcefb6f04b954893696daaaf91bde1 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 10 Dec 2021 15:19:20 -0800 Subject: [PATCH 020/215] [FRONTEND] Minor accumulated style and warning fixes (#388) - Fix some whitespace. - Make an undeclared dependency on `pytest` explicit. - Fix deprecated `description-file` use. - `#ifdef` out a deprecated `PyEval_InitThreads` call. - Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests. 
- Fix a deprecated cast in `test_core.py`. - Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7. - Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer. - Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred. - Fix a few bad escapes. --- lib/codegen/selection/generator.cc | 6 +++--- python/requirements-test.txt | 1 + python/setup.cfg | 2 +- python/src/pybind11/detail/internals.h | 2 ++ python/test/unit/language/test_core.py | 5 +++-- python/test/unit/language/test_random.py | 5 ++--- python/triton/code_gen.py | 14 +++++++++----- python/triton/language/__init__.py | 5 ++--- python/triton/language/core.py | 3 +++ python/triton/language/random.py | 4 ++-- 10 files changed, 28 insertions(+), 19 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 3c4fae3d8..d5c5c4902 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() #define f32_ty builder_->getFloatTy() -#define i8_ty builder_->getInt8Ty() +#define i8_ty builder_->getInt8Ty() #define i32_ty builder_->getInt32Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) @@ -163,8 +163,8 @@ Type *generator::cvt(ir::type *ty) { case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); - case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); - case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); + case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); + case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_); case ir::type::TokenTyID: return Type::getTokenTy(*ctx_); diff --git a/python/requirements-test.txt b/python/requirements-test.txt index 4a1b49122..48b6d3be3 100644 --- a/python/requirements-test.txt +++ b/python/requirements-test.txt @@ -1 +1,2 @@ +pytest scipy >= 1.7.1 diff --git a/python/setup.cfg b/python/setup.cfg index 224a77957..08aedd7e6 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -1,2 +1,2 @@ [metadata] -description-file = README.md \ No newline at end of file +description_file = README.md diff --git a/python/src/pybind11/detail/internals.h b/python/src/pybind11/detail/internals.h index f1dd38764..4f25759d3 100644 --- a/python/src/pybind11/detail/internals.h +++ b/python/src/pybind11/detail/internals.h @@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() { auto *&internals_ptr = *internals_pp; internals_ptr = new internals(); #if defined(WITH_THREAD) + #if PY_VERSION_HEX < 0x03090000 PyEval_InitThreads(); + #endif PyThreadState *tstate = PyThreadState_Get(); #if PY_VERSION_HEX >= 0x03070000 internals_ptr->tstate = PyThread_tss_alloc(); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6359857fe..e85c399b3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): ('float32', 'int32', True) ]) def 
test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device) + x0 = 43 if dtype_x.startswith('int') else 43.5 + x = torch.tensor([x0], dtype=cvt[dtype_x], device=device) # triton kernel @triton.jit @@ -665,4 +666,4 @@ def test_noop(device='cuda'): def kernel(x): pass x = triton.testing.random((1,), dtype=torch.int32, device=device) - kernel[(1, )](x) \ No newline at end of file + kernel[(1, )](x) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 4c1261f1d..4d4501556 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -74,9 +74,8 @@ class CustomPhilox4x: return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) def _raise_key(self, key): - ret0 = key[0] + self._config.PHILOX_KEY_A - ret1 = key[1] + self._config.PHILOX_KEY_B - return np.array([ret0, ret1], dtype=self._dtype) + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) def random_raw(self): counter = self._counter diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index deede2530..96948e360 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -10,6 +10,7 @@ import os import pickle import subprocess import os +import warnings from .tools.disasm import extract import torch import triton @@ -475,7 +476,11 @@ class CodeGenerator(ast.NodeVisitor): def visit(self, node): if node is not None: self.last_node = node - return super().visit(node) + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) + return super().visit(node) def generic_visit(self, node): typename = type(node).__name__ @@ -512,12 +517,11 @@ class LoadedBinary: class CompilationError(Exception): - def __init__(self, src, node, err): + def __init__(self, src, node): self.message = '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' - self.message += '\n Error: ' + str(err) super().__init__(self.message) - self.args = (src, node, err) + self.args = (src, node) class OutOfResources(Exception): @@ -618,7 +622,7 @@ class Kernel: node = generator.last_node if node is None or isinstance(e, (NotImplementedError, CompilationError)): raise e - raise CompilationError(self.fn.src, node, e) + raise CompilationError(self.fn.src, node) from e # Compile to machine code if torch.version.hip is None: backend = _triton.runtime.backend.CUDA diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index b96260c51..2f3f4ea05 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,3 @@ -from . import core -from . import random +from . 
import core, random from .core import * -from .random import * +from .random import rand, randint, randint4x, randn diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d7240fcf8..55b5bc0d9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -72,6 +72,9 @@ class dtype: ctx = builder.context return self.init(ctx) + def __str__(self): + return f"dtype({self.init.__name__})" + class pointer_dtype: def __init__(self, element_ty): diff --git a/python/triton/language/random.py b/python/triton/language/random.py index a831af487..e1ac3c30a 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2): def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, - returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. @@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, - returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. From 558555630f8633547821389c57b2b0bb4d650c1b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 16 Dec 2021 17:55:35 -0800 Subject: [PATCH 021/215] [FRONTEND] Added xor_sum --- include/triton/ir/dispatch.h | 1 + include/triton/ir/instructions.h | 3 ++- lib/codegen/selection/generator.cc | 2 ++ lib/ir/dispatch.cc | 7 +++++++ python/src/triton.cc | 1 + python/triton/language/core.py | 5 +++++ 6 files changed, 18 insertions(+), 1 deletion(-) diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index 0c8295948..b6cc7dcc1 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -89,6 +89,7 @@ struct dispatch{ static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder); // math static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index fdb2fd411..28ff9f3d6 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -788,7 +788,8 @@ class reduce_inst: public builtin_inst { public: enum op_t{ ADD, SUB, MAX, MIN, - FADD, FSUB, FMAX, FMIN + FADD, FSUB, FMAX, FMIN, + XOR }; private: diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d5c5c4902..a838bfa5a 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1890,6 +1890,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::FSUB: return fsub(x, y); case ir::reduce_inst::FMAX: return max_num(x, y); case ir::reduce_inst::FMIN: return min_num(x, y); + case ir::reduce_inst::XOR: return xor_(x, y); default: throw 
std::runtime_error("unreachable"); } }; @@ -1904,6 +1905,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break; + case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break; default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 811e5c819..477f4dce0 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -714,6 +714,13 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); } +ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) { + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + if (!scalar_ty->is_integer_ty()) + throw semantic_error("xor_sum only supported for integers"); + return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR); +} + //===----------------------------------------------------------------------===// // Math diff --git a/python/src/triton.cc b/python/src/triton.cc index aca171dae..92df2ae27 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -520,6 +520,7 @@ void init_triton_frontend(py::module &&m) { m.def("min", &ir::dispatch::min, ret::reference); m.def("max", &ir::dispatch::max, ret::reference); m.def("sum", &ir::dispatch::sum, ret::reference); + m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); // math m.def("umulhi", &ir::dispatch::umulhi, ret::reference); m.def("exp", &ir::dispatch::exp, ret::reference); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 55b5bc0d9..5db77efdc 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -719,6 +719,11 @@ def min(input, axis, _builder=None): def sum(input, axis, _builder=None): return frontend.sum(input, axis, _builder) +@builtin +@_add_reduction_docstr("xor sum") +def xor_sum(input, axis, _builder=None): + return frontend.xor_sum(input, axis, _builder) + # ----------------------- # Internal for debugging From e0b92c138090e67898a8f55d9197ef1966cb8dbe Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 16 Dec 2021 18:37:51 -0800 Subject: [PATCH 022/215] [FRONTEND] Reverted `from .random import *`. There are still some namespace errors in the Triton frontend apparently --- python/triton/language/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 2f3f4ea05..a7f341f16 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,3 +1,3 @@ from . 
import core, random from .core import * -from .random import rand, randint, randint4x, randn +from .random import * From eb077fc993787d7473c33eb8c345fa4cc6745ede Mon Sep 17 00:00:00 2001 From: Victor Date: Thu, 16 Dec 2021 22:09:52 -0800 Subject: [PATCH 023/215] [RUNTIME] fixed NVidia DLL names on Windows (#392) --- lib/driver/dispatch.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 4059ac235..9e2aca432 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -91,9 +91,13 @@ void* dispatch::fname ## _; bool dispatch::cuinit(){ if(cuda_==nullptr){ + #ifdef _WIN32 + cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY); + #else cuda_ = dlopen("libcuda.so", RTLD_LAZY); if(!cuda_) cuda_ = dlopen("libcuda.so.1", RTLD_LAZY); + #endif if(!cuda_) throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH."); } @@ -176,8 +180,13 @@ CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent) * NVML * ------------------- */ bool dispatch::nvmlinit(){ + #ifdef _WIN32 + if(nvml_==nullptr) + nvml_ = dlopen("nvml.dll", RTLD_LAZY); + #else if(nvml_==nullptr) nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY); + #endif nvmlReturn_t (*fptr)(); nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2"); *reinterpret_cast(&fptr) = nvmlInit_v2_; From e0628129692a9ac29267c75ac47c7f0e83c7670a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 17 Dec 2021 12:44:47 -0800 Subject: [PATCH 024/215] [CODEGEN] Disabled peephole for masked load + select -- masked_load doesn't work as expected when vectorized --- lib/codegen/transform/peephole.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index ac0220ebc..f7ebdad80 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -284,7 +284,8 @@ void peephole::run(ir::module &mod) { // was_modified = was_modified || rewrite_trans_phi(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); - was_modified = was_modified || rewrite_select_masked_load(i, builder); + // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD +// was_modified = was_modified || rewrite_select_masked_load(i, builder); was_modified = was_modified || rewrite_cvt_layout(i, builder); if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) was_modified = was_modified || rewrite_load_to_shared(i, builder); From 4e93b41c528a84660dd42f96beb278fcd84948d2 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 17 Dec 2021 18:06:21 -0800 Subject: [PATCH 025/215] [GENERAL] Some minor fixups (#393) * [RUNTIME] Now displaying error message when generated PTX is invalid * [CODEGEN] Now converting `if` condition to bool implicitly --- lib/driver/llvm.cc | 136 +++++++++++++++++++++----------------- python/triton/code_gen.py | 5 +- 2 files changed, 81 insertions(+), 60 deletions(-) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index f3c76ce77..db64aa73b 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -158,12 +158,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ } std::string ptx_to_cubin(const std::string& ptx, int cc) { - std::string ptxas = "ptxas"; std::string version; - int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; - if(!use_system_ptxas) - return ""; - + // search pathes for ptxas + std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; + 
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); + if(!triton_ptxas.empty()) + ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); + // see what path for ptxas are valid + std::vector working_ptxas; + for(std::string prefix: ptxas_prefixes){ + std::string ptxas = prefix + "ptxas"; + bool works = tools::exec(ptxas + " --version 2>&1", version) == 0; + if(works) + working_ptxas.push_back(ptxas); + } + // error if no working ptxas was found + if(working_ptxas.empty()) + throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" + " but a working version could not be found."); + std::string ptxas = working_ptxas.front(); // compile ptx with ptxas char _fsrc[] = "/tmp/triton_k_XXXXXX"; char _flog[] = "/tmp/triton_l_XXXXXX"; @@ -180,6 +193,11 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { int err; cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; err = system(cmd.c_str()); + if(err != 0){ + std::ifstream _log(_flog); + std::string log(std::istreambuf_iterator(_log), {}); + throw std::runtime_error("Internal Triton PTX codegen error: \n" + log); + } CUmodule ret; std::ifstream _cubin(_fbin, std::ios::binary ); std::string cubin(std::istreambuf_iterator(_cubin), {}); @@ -191,62 +209,62 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { return cubin; } -CUmodule ptx_to_cumodule(const std::string& ptx, int cc) { - // JIT compile source-code - try{ - // use ptxas if present in PATH. Otherwise, use JIT from the driver - std::string ptxas = "ptxas"; - std::string version; - int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; +//CUmodule ptx_to_cumodule(const std::string& ptx, int cc) { +// // JIT compile source-code +// try{ +// // use ptxas if present in PATH. 
Otherwise, use JIT from the driver +// std::string ptxas = "ptxas"; +// std::string version; +// int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; - // Use PTXAS via system call - if(use_system_ptxas){ - // compile ptx with ptxas - char _fsrc[] = "/tmp/triton_k_XXXXXX"; - char _flog[] = "/tmp/triton_l_XXXXXX"; - mkstemp(_fsrc); - mkstemp(_flog); - std::string fsrc = _fsrc; - std::string flog = _flog; - std::string fbin = fsrc + ".o"; - const char* _fbin = fbin.c_str(); - std::ofstream ofs(fsrc); - ofs << ptx; - ofs.close(); - std::string cmd; - int err; - cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; - err = system(cmd.c_str()); - CUmodule ret; - std::ifstream _cubin(_fbin, std::ios::binary ); - std::string cubin(std::istreambuf_iterator(_cubin), {}); - _cubin.close(); - dispatch::cuModuleLoadData(&ret, cubin.c_str()); - unlink(_fsrc); - unlink(_flog); - unlink(_fbin); - return ret; - } +// // Use PTXAS via system call +// if(use_system_ptxas){ +// // compile ptx with ptxas +// char _fsrc[] = "/tmp/triton_k_XXXXXX"; +// char _flog[] = "/tmp/triton_l_XXXXXX"; +// mkstemp(_fsrc); +// mkstemp(_flog); +// std::string fsrc = _fsrc; +// std::string flog = _flog; +// std::string fbin = fsrc + ".o"; +// const char* _fbin = fbin.c_str(); +// std::ofstream ofs(fsrc); +// ofs << ptx; +// ofs.close(); +// std::string cmd; +// int err; +// cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; +// err = system(cmd.c_str()); +// CUmodule ret; +// std::ifstream _cubin(_fbin, std::ios::binary ); +// std::string cubin(std::istreambuf_iterator(_cubin), {}); +// _cubin.close(); +// dispatch::cuModuleLoadData(&ret, cubin.c_str()); +// unlink(_fsrc); +// unlink(_flog); +// unlink(_fbin); +// return ret; +// } - // Use PTXAS included in driver - CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, - CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, - CU_JIT_LOG_VERBOSE}; - unsigned int errbufsize = 8192; - unsigned int logbufsize = 8192; - char _err[errbufsize]; - char _log[logbufsize]; - void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; - CUmodule ret; - dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval); - return ret; - } - catch(exception::cuda::invalid_ptx const &){ - std::cout << ptx << std::endl; - std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; - throw; - } -} +// // Use PTXAS included in driver +// CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, +// CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, +// CU_JIT_LOG_VERBOSE}; +// unsigned int errbufsize = 8192; +// unsigned int logbufsize = 8192; +// char _err[errbufsize]; +// char _log[logbufsize]; +// void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; +// CUmodule ret; +// dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval); +// return ret; +// } +// catch(exception::cuda::invalid_ptx const &){ +// std::cout << ptx << std::endl; +// std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; +// throw; +// } +//} /* ------------------------ */ // HIP // diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 96948e360..d00a9d50c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -248,7 +248,8 @@ class 
CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if self.is_triton_object(cond): + if isinstance(cond, triton.language.block): + cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None @@ -273,6 +274,8 @@ class CodeGenerator(ast.NodeVisitor): self.module.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: + if isinstance(cond, triton.language.constexpr): + cond = cond.value if cond: self.visit_compound_statement(node.body) else: From fa62b4a8f6efa3f976e40afe346ed8648272ff61 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 17 Dec 2021 20:11:45 -0800 Subject: [PATCH 026/215] [FRONTEND] better stringification (#394) - Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace. - Better `__str__` on `TensorWrapper`, `dtype`, and `block`. --- include/triton/ir/type.h | 4 ++-- python/triton/code_gen.py | 7 +++++-- python/triton/language/core.py | 14 +++++++++++++- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 33c74f245..7fb14877a 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -74,8 +74,8 @@ public: bool is_fp8_ty() const { return id_ == FP8TyID; } bool is_fp16_ty() const { return id_ == FP16TyID; } bool is_bf16_ty() const { return id_ == BF16TyID; } - bool is_fp32_ty() const { return id_ == FP32TyID; } - bool is_fp64_ty() const { return id_ == FP64TyID; } + bool is_fp32_ty() const { return id_ == FP32TyID; } + bool is_fp64_ty() const { return id_ == FP64TyID; } bool is_label_ty() const { return id_ == LabelTyID;} bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index d00a9d50c..688508265 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -521,10 +521,10 @@ class LoadedBinary: class CompilationError(Exception): def __init__(self, src, node): - self.message = '\n'.join(src.split('\n')[:node.lineno]) + self.message = f'at {node.lineno}:{node.col_offset}:\n' + self.message += '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' super().__init__(self.message) - self.args = (src, node) class OutOfResources(Exception): @@ -1085,6 +1085,9 @@ class TensorWrapper: def data_ptr(self): return self.base.data_ptr() + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + def reinterpret(tensor, dtype): return TensorWrapper(tensor, dtype) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 5db77efdc..1c28cdef7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -68,12 +68,20 @@ class dtype: def __init__(self, init): self.init = init + @property + def name(self) -> str: + # The init functions are named something like 'get_int8'. Strip the prefix. 
+ nom = self.init.__name__ + prefix = 'get_' + assert nom.startswith(prefix) + return nom[len(prefix):] + def handle(self, builder): ctx = builder.context return self.init(ctx) def __str__(self): - return f"dtype({self.init.__name__})" + return self.name class pointer_dtype: @@ -131,6 +139,10 @@ class block: # Data-type wrapper self.dtype = block._init_dtype(self.handle.type.scalar) + def __str__(self) -> str: + # ex. "float32[3,4]" + return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']' + @builtin def __add__(self, other, _builder=None): return frontend.add(self, other, _builder) From 4a8953efa349569a298c0011997f89cc8abe7222 Mon Sep 17 00:00:00 2001 From: daadaada Date: Sun, 19 Dec 2021 10:03:22 +0800 Subject: [PATCH 027/215] [FRONTEND] Replace the legacy print call in triton.cc with the SlotTracker-based one. (#396) The legacy print call will assign names (e.g., %10) to values, which can be undesirable in some cases. --- python/src/triton.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 92df2ae27..01ad402aa 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -440,7 +440,7 @@ void init_triton_codegen(py::module &&m) { // record asm as we generate asm_map_t asm_map; std::ostringstream ttir; - ir::print(ir, ttir); + ir.print(ttir); asm_map["ttir"] = py::cast(ttir.str()); llvm::LLVMContext ctx; if(backend == CUDA) From 5cdb948c059480b3463f6d4cf593e675f46696b9 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Tue, 21 Dec 2021 09:46:05 -0800 Subject: [PATCH 028/215] [FRONTEND] signed-integer math fixes and testing (#395) - Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway. - Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now. - When testing signed integer operations, use random numbers from the full range of the type. - Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type. - Fix a bad `CompilationError` invocation. - Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8. --- lib/ir/dispatch.cc | 63 +++++++++++++------------- python/test/unit/language/test_core.py | 59 ++++++++++++++++++++---- python/triton/code_gen.py | 5 +- python/triton/testing.py | 13 ++++-- 4 files changed, 93 insertions(+), 47 deletions(-) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 477f4dce0..c616b2fd4 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -33,21 +33,28 @@ ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ return a_rank > b_rank ? 
a_ty : b_ty; } -ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){ +enum class DivOrMod { NO, YES }; + +ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) { context &ctx = a_ty->get_context(); // 1) if one operand is double, the other is implicitly // converted to double - if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) + if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) return type::get_fp64_ty(ctx); // 2) if one operand is float, the other is implicitly // converted to float - if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) + if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) return type::get_fp32_ty(ctx); - // 3 ) if one operand is half, the other is implicitly - // converted to half - if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) - return type::get_fp16_ty(ctx); - if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) + // 3 ) if one operand is half, the other is implicitly converted to half + // unless we're doing / or %, which do not exist natively in PTX for fp16. + if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) { + if (div_or_mod == DivOrMod::YES) { + return type::get_fp32_ty(ctx); + } else { + return type::get_fp16_ty(ctx); + } + } + if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) throw_unreachable("augment_types"); // 4 ) both operands are integer and undergo // integer promotion @@ -78,7 +85,7 @@ void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){ void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, - bool arithmetic_check = true){ + bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) { // implicit broadcasting std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); // implicit typecasting @@ -86,8 +93,8 @@ void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* buil ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); - if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){ - ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty); + if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) { + ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod); lhs = dispatch::cast(lhs, ret_sca_ty, builder); rhs = dispatch::cast(rhs, ret_sca_ty, builder); } @@ -140,7 +147,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde } ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, false); + binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); // float / int @@ -168,7 +175,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu } ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ - binary_op_type_checking(input, other, builder, false, false, false); + binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ @@ -181,7 +188,7 @@ ir::value 
*dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b } ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); + binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *scalar_ty = input->get_type()->get_scalar_ty(); // float % int if (scalar_ty->is_floating_point_ty()) @@ -193,51 +200,43 @@ ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builde } -void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){ +void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) { binary_op_type_checking(input, other, builder, false, false, false); ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) throw_incompatible_types(input_sca_ty, other_sca_ty); - // for some reason pytorch assigns the result of binary op to have the type of the lhs... - if(force_lhs_type){ - if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth()) - other = dispatch::cast(other, input_sca_ty, builder); - } - else{ - if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth()) - input = dispatch::cast(input, other_sca_ty, builder); - else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth()) - other = dispatch::cast(other, input_sca_ty, builder); - } - + if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth()) + input = dispatch::cast(input, other_sca_ty, builder); + else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth()) + other = dispatch::cast(other, input_sca_ty, builder); } ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder, true); + bitwise_op_type_checking(input, other, builder); return builder->create_and(input, other); } ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder, true); + bitwise_op_type_checking(input, other, builder); return builder->create_or(input, other); } ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder, true); + bitwise_op_type_checking(input, other, builder); return builder->create_xor(input, other); } ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder, false); + bitwise_op_type_checking(input, other, builder); return builder->create_lshr(input, other); } ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder, false); + bitwise_op_type_checking(input, other, builder); return builder->create_shl(input, other); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e85c399b3..785ca49ac 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): triton.testing.assert_almost_equal(z_ref, z_tri) -def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'): +def _test_binary(dtype_x, dtype_y, expr, 
torch_expr=None, mode_x='real', mode_y='real', device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -82,12 +82,12 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) - y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device) + x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17) + y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144) if mode_x == 'nan': x[:] = float('nan') if mode_y == 'nan': y[:] = float('nan') # reference result - z_ref = eval(expr) + z_ref = eval(expr if torch_expr is None else torch_expr) # triton result z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device) kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) @@ -95,17 +95,56 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr) +def _fake_fmod(x, y): + """ + Triton % (for both integers and floats) has the same semantics as torch + fmod, but torch fmod doesn't work on integers until torch 1.8. + `_fake_fmod` gives the same semantics but works on all versions of torch. + """ + z = torch.remainder(x, y) + return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z) + + +def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: + # The result of x % y is ill-conditioned if x % y is much smaller than x. + # pytorch/CUDA has slightly different (probably better) rounding on + # remainders than stock LLVM. We currently don't expect to match it + # bit-for-bit. + return (dtype_x, dtype_y) in [ + ('int32', 'float16'), + ('int32', 'float32'), + ('int64', 'float16'), + ('int64', 'float32'), + ('int64', 'float64'), + ] + # --------------- # test binary ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ - (dtype_x, dtype_y, f' x {op} y') \ - for op in ['+', '-', '*', '/', '%'] \ - for dtype_x in dtypes \ +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes for dtype_y in dtypes ]) -def test_bin_op(dtype_x, dtype_y, expr, device='cuda'): - _test_binary(dtype_x, dtype_y, expr, device=device) +def test_bin_op(dtype_x, dtype_y, op, device='cuda'): + expr = f' x {op} y' + if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes: + # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders. + torch_expr = '_fake_fmod(x, y)' + elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. 
+ torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)' + else: + torch_expr = None + if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): + with pytest.raises(AssertionError, match='Arrays are not almost equal'): + _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device) + else: + _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device) + # --------------- diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 688508265..e1091eff7 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -482,7 +482,8 @@ class CodeGenerator(ast.NodeVisitor): with warnings.catch_warnings(): # The ast library added visit_Constant and deprecated some other # methods but we can't move to that without breaking Python 3.6 and 3.7. - warnings.simplefilter("ignore", DeprecationWarning) + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 return super().visit(node) def generic_visit(self, node): @@ -905,7 +906,7 @@ class JITFunction: node = generator.last_node if node is None or isinstance(e, (NotImplementedError, CompilationError)): raise e - raise CompilationError(self.src, node, e) + raise CompilationError(self.src, node) from e # - when `.src` attribute is set, cache path needs # to be reinitialized diff --git a/python/triton/testing.py b/python/triton/testing.py index 08ad62580..051a8f378 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -89,14 +89,21 @@ def assert_allclose(x, y, tol=1e-2): assert allclose(x, y, tol) -def random(shape, dtype, device): - torch.manual_seed(0) +def random(shape, dtype, device, seed=0): + """ + Override the seed in tests if you're calling this function twice and don't + want the same result for both calls. + """ + torch.manual_seed(seed) if isinstance(shape, int): shape = (shape, ) if dtype == torch.bool: return torch.randint(0, 2, shape, dtype=dtype, device=device) if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: - return torch.randint(1, 32, shape, dtype=dtype, device=device) + iinfo = torch.iinfo(dtype) + x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device) + x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. + return x if dtype in [torch.float16, torch.float32, torch.float64]: return torch.normal(0, 1, shape, dtype=dtype, device=device) raise RuntimeError(f'Unknown dtype {dtype}') From 39d4bfed8391dfbc88c14db14c7bad673beaab08 Mon Sep 17 00:00:00 2001 From: daadaada Date: Wed, 22 Dec 2021 01:56:10 +0800 Subject: [PATCH 029/215] [OPS] Add performance model for gemm/gemv (#397) Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process. 
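For context, the new hooks are consumed through `triton.autotune` roughly as in the sketch below, which mirrors how `python/triton/ops/matmul.py` is wired up further down in this patch; the toy kernel, its signature and its two configs are illustrative placeholders rather than the real matmul kernel.

    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import estimate_matmul_time, prune_num_stages

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=2, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        prune_configs_by={
            'perf_model': estimate_matmul_time,       # analytical model that predicts the runtime of each config
            'top_k': 10,                              # only the 10 most promising configs are actually benchmarked
            'prune_num_stages_by': prune_num_stages,  # optional: prune num_stages candidates before the model runs
        },
    )
    @triton.jit
    def _toy_kernel(X, M, N, K,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                    BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
        pass

Pruning with the analytical model before any benchmarking is what keeps the much larger config space added here affordable at autotune time.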
--- include/triton/codegen/transform/cts.h | 2 +- include/triton/codegen/transform/peephole.h | 2 +- include/triton/ir/instructions.h | 11 +- lib/codegen/pass.cc | 3 +- lib/codegen/transform/cts.cc | 2 +- lib/codegen/transform/pipeline.cc | 2 +- python/src/triton.cc | 35 ++++++ python/test/regression/test_performance.py | 6 +- python/triton/code_gen.py | 78 +++++++++++-- python/triton/ops/matmul.py | 34 ++++-- python/triton/ops/matmul_perf_model.py | 116 ++++++++++++++++++++ python/triton/testing.py | 25 +++++ 12 files changed, 289 insertions(+), 27 deletions(-) create mode 100644 python/triton/ops/matmul_perf_model.py diff --git a/include/triton/codegen/transform/cts.h b/include/triton/codegen/transform/cts.h index dcc5f36c2..70fbc474b 100644 --- a/include/triton/codegen/transform/cts.h +++ b/include/triton/codegen/transform/cts.h @@ -33,4 +33,4 @@ private: } } -#endif +#endif \ No newline at end of file diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h index 1b015fb41..0e1ed222e 100644 --- a/include/triton/codegen/transform/peephole.h +++ b/include/triton/codegen/transform/peephole.h @@ -35,7 +35,7 @@ private: bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder); bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder); - + public: peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {} void run(ir::module &mod); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 28ff9f3d6..699d22257 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -455,7 +455,7 @@ public: // masked load async class masked_load_async_inst: public load_inst { private: - std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); } + std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); } masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next); @@ -728,12 +728,21 @@ public: class dot_inst: public builtin_inst { public: enum TransT { NoTrans, Trans }; + enum DataType { + FP8, FP16, BF16, TF32, FP32, + INT1, INT4, INT8, INT32, + UNKNOWN, + }; private: dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); std::string repr_impl() const { return "dot"; } bool is_prefetched_ = false; + DataType C_type_ = DataType::FP32; + DataType A_type_ = DataType::FP16; + DataType B_type_ = DataType::FP16; + public: bool is_prefetched() const { return is_prefetched_; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 845e2e36d..d38d81a9c 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -52,7 +52,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC peephole.run(ir); dce.run(ir); pipeline.run(ir); - dce.run(ir); + dce.run(ir); disassociate.run(ir); dce.run(ir); align.run(ir); @@ -85,6 +85,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); + // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc index 2641dad53..c223d2413 100644 --- 
a/lib/codegen/transform/cts.cc +++ b/lib/codegen/transform/cts.cc @@ -94,4 +94,4 @@ void cts::run(ir::module &mod) { } } -} +} \ No newline at end of file diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index cc7835bbc..bc249841b 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -327,4 +327,4 @@ void pipeline::run(ir::module &mod) { } } -} +} \ No newline at end of file diff --git a/python/src/triton.cc b/python/src/triton.cc index 01ad402aa..ce56d9c26 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) { return bin; }); + m.def("cc", [](backend_t backend, uint64_t device) -> int { + if (backend == CUDA) { + CUdevice dev = (CUdevice)device; + int major = cuGetInfo(dev); + int minor = cuGetInfo(dev); + return major*10 + minor; + } + return -1; + }); + // query maximum shared memory m.def("max_shared_memory", [](backend_t backend, uint64_t device) { if (backend == HOST) @@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) { return -1; }); + // query DRAM & L2 cache + m.def("memory_clock_rate", [](backend_t backend, uint64_t device) { + if (backend == CUDA) return cuGetInfo(device); + return -1; + }); + m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) { + if (backend == CUDA) return cuGetInfo(device); + return -1; + }); + m.def("l2_cache_size", [](backend_t backend, uint64_t device) { + if (backend == CUDA) return cuGetInfo(device); + return -1; + }); + + // query clock rate (in kilohertz) + m.def("clock_rate", [](backend_t backend, uint64_t device) { + if (backend == CUDA) return cuGetInfo(device); + return -1; + }); + + m.def("num_sm", [](backend_t backend, uint64_t device) { + if (backend == CUDA) return cuGetInfo(device); + return -1; + }); + // enqueue m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel, uint64_t grid_0, uint64_t grid_1, uint64_t grid_2, diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index eff21fdfd..ce93786b8 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -25,7 +25,7 @@ def nvsmi(attrs): matmul_data = { # square (256 , 256 , 256 ) : {'v100': 0.027}, - (512 , 512 , 512 ) : {'v100': 0.141}, + (512 , 512 , 512 ) : {'v100': 0.158}, (1024, 1024, 1024 ) : {'v100': 0.466}, (2048, 2048, 2048 ) : {'v100': 0.680}, (4096, 4096, 4096 ) : {'v100': 0.831}, @@ -35,10 +35,10 @@ matmul_data = { (16 , 4096, 4096 ) : {'v100': 0.0883}, (16 , 8192, 8192 ) : {'v100': 0.101}, (64 , 1024, 1024 ) : {'v100': 0.073}, - (64 , 4096, 4096 ) : {'v100': 0.228}, + (64 , 4096, 4096 ) : {'v100': 0.270}, (64 , 8192, 8192 ) : {'v100': 0.360}, (1024, 64 , 1024 ) : {'v100': 0.0692}, - (4096, 64 , 4096 ) : {'v100': 0.223}, + (4096, 64 , 4096 ) : {'v100': 0.264}, (8192, 64 , 8192 ) : {'v100': 0.323}, # # deep reductions # (64 , 64 , 16384) : {'v100': 0.}, diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index e1091eff7..2f6ddf3c1 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -17,6 +17,10 @@ import triton import triton._C.libtriton.triton as _triton from filelock import FileLock import dbm +import tempfile +from typing import Optional, Dict +import time + class CodeGenerator(ast.NodeVisitor): @@ -508,6 +512,7 @@ class LoadedBinary: device) self.bin = bin self.asm = bin.asm + self.sass = '' self.module = module self.kernel = kernel self.device = device @@ 
-519,6 +524,19 @@ class LoadedBinary: self.bin.num_warps * 32, 1, 1, args, self.bin.shared_mem) + def get_sass(self, fun=None): + if self.sass: + return self.sass + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(self.asm['cubin']) + self.sass = extract(path, fun) + finally: + os.remove(path) + self.asm['sass'] = self.sass + return self.sass + class CompilationError(Exception): def __init__(self, src, node): @@ -530,8 +548,8 @@ class CompilationError(Exception): class OutOfResources(Exception): def __init__(self, required, limit, name): - self.message = f'out of resource: {name}'\ - f'Required: {required}'\ + self.message = f'out of resource: {name}, '\ + f'Required: {required}, '\ f'Hardware limit: {limit}' super().__init__(self.message) self.args = (required, limit, name) @@ -727,7 +745,13 @@ class Launcher: class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero): + def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + ''' if not configs: self.configs = [Config(dict(), num_warps=4, num_stages=2)] else: @@ -744,7 +768,16 @@ class Autotuner: args[i].zero_() self.hook = _hook self.arg_names = arg_names - + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'prune_num_stages_by' in prune_configs_by: + prune_num_stages_by = prune_configs_by['prune_num_stages_by'] + else: + perf_model, top_k, prune_num_stages_by = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.prune_num_stages_by = prune_num_stages_by + def _bench(self, *args, config, **meta): # check for conflicts, i.e. 
meta-parameters both provided # as kwargs and by the autotuner @@ -768,13 +801,29 @@ class Autotuner: if len(self.configs) > 1: key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: + # prune configs + pruned_configs = self.configs + if self.prune_num_stages_by: + pruned_configs = self.prune_num_stages_by(self.configs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} + pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k] + bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) \ - for config in self.configs} + for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start self.cache[key] = builtins.min(timings, key=timings.get) self.hook(args) + self.configs_timings = timings config = self.cache[key] else: config = self.configs[0] + self.best_config = config if config.pre_hook != None: config.pre_hook(self.nargs) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) @@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor): module = inspect.getmodule(func) if module and module.__name__.startswith('triton.'): return + if inspect.isbuiltin(func): + return if not hasattr(func, 'hash'): src = textwrap.dedent(inspect.getsource(func)) tree = ast.parse(src) @@ -957,8 +1008,16 @@ class Config: self.num_stages = num_stages self.pre_hook = pre_hook + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f'{k}: {v}') + res.append(f'num_warps: {self.num_warps}') + res.append(f'num_stages: {self.num_stages}') + return ', '.join(res) -def autotune(configs, key, reset_to_zero=None): + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): """ Decorator for auto-tuning a :code:`triton.jit`'d function. @@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None): :type configs: list[triton.Config] :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 
:type reset_to_zero: list[str] """ def decorator(fn): def wrapper(kernel): - return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero) + return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) fn.kernel_decorators.append(wrapper) return fn @@ -1023,7 +1086,6 @@ def heuristics(values): assert v not in meta meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) return kernel(*args, **meta) - return fun fn.kernel_decorators.append(wrapper) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index ae404b8d6..8b7299a8b 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -1,15 +1,33 @@ import torch import triton.language as tl import triton +from .matmul_perf_model import * def init_to_zero(name): return lambda nargs: nargs[name].zero_() +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + @triton.heuristics({ 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, }) @triton.autotune( configs=[ + # basic configs for compute-bound matmuls triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), @@ -19,17 +37,13 @@ def init_to_zero(name): triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')), - triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')), - ], + ] + get_configs_io_bound(), key=['M', 'N', 'K'], + prune_configs_by={ + 'prune_num_stages_by' : prune_num_stages, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, 
) @triton.jit def _kernel(A, B, C, M, N, K, diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py new file mode 100644 index 000000000..16667a7b1 --- /dev/null +++ b/python/triton/ops/matmul_perf_model.py @@ -0,0 +1,116 @@ +import torch +import triton +import triton._C.libtriton.triton as _triton +from triton.testing import get_dram_gbps, get_max_tensorcore_tflops +import heapq + +def get_tensorcore_tflops(backend, device, num_ctas, num_warps): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device) + return tflops + +def estimate_matmul_time( + # backend, device, + num_warps, num_stages, + M, N, K, + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, + debug=False, **kwargs +): + ''' return estimated running time in ms + = max(compute, loading) + store ''' + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() + + num_cta_m = triton.cdiv(M, BLOCK_M) + num_cta_n = triton.cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2*M*N*K / (1024*1024*1024) # GOPS + tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) + compute_ms = total_ops / tput + + # time to load data + num_sm = _triton.runtime.num_sm(backend, device) + active_cta_ratio = min(1, num_ctas/num_sm) + active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
+ # assume 80% of (following) loads are in L2 cache + load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2) + load_a_l2 = M*K*2*0.8*(num_cta_n-1) + load_b_dram = N*K*2*(1+0.2*(num_cta_m-1)) + load_b_l2 = N*K*2*0.8*(num_cta_m-1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024*1024) + # loading time in ms + load_ms = total_dram/dram_bw + total_l2/l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram /store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram/reduce_bw + # c.zero_() + zero_ms = M*N*2/(1024*1024)/store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms + +def prune_num_stages(configs): + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() + cc = _triton.runtime.cc(backend, device) + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if cc >= 80: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16) + mma_cycles = mmas/min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? 
+ optimal_num_stages = ldgsts_latency/mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \ + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs \ No newline at end of file diff --git a/python/triton/testing.py b/python/triton/testing.py index 051a8f378..f274e808f 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,5 +1,6 @@ import torch import os +import triton._C.libtriton.triton as _triton from .code_gen import OutOfResources import subprocess import sys @@ -320,3 +321,27 @@ def perf_report(benchmarks): """ wrapper = lambda fn: Mark(fn, benchmarks) return wrapper + +def get_dram_gbps(backend=None, device=None): + ''' return DRAM bandwidth in GB/s ''' + # assert backend == CUDA + if not backend: + backend = _triton.runtime.backend.CUDA + if not device: + device = torch.cuda.current_device() + mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) + bus_width = _triton.runtime.global_memory_bus_width(backend, device) + bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s + return bw_gbps + +def get_max_tensorcore_tflops(backend, device): + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz + # assume fp32 += fp16*fp16 + cc = _triton.runtime.cc(backend, device) + if cc < 80: + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + ops_per_sub_core = 512 + tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) + return tflops \ No newline at end of file From 2509124dd05d7eb5b16e1f5e4714e565f39145c1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 21 Dec 2021 14:31:51 -0800 Subject: [PATCH 030/215] [DRIVER] Fixed some issue with how ptxas is used (#399) Now using tmpnam and properly deleting temporaries when an exception is raised --- lib/driver/llvm.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index db64aa73b..7248d6cec 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -178,16 +178,14 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { " but a working version could not be found."); std::string ptxas = working_ptxas.front(); // compile ptx with ptxas - char _fsrc[] = "/tmp/triton_k_XXXXXX"; - char _flog[] = "/tmp/triton_l_XXXXXX"; - mkstemp(_fsrc); - mkstemp(_flog); - std::string fsrc = _fsrc; - std::string flog = _flog; + char _fsrc[L_tmpnam]; + char _flog[L_tmpnam]; + std::string fsrc = std::tmpnam(_fsrc); + std::string flog = std::tmpnam(_flog); std::string fbin = fsrc + ".o"; const char* _fbin = fbin.c_str(); std::ofstream ofs(fsrc); - ofs << ptx; + ofs << ptx << std::endl; ofs.close(); std::string cmd; int err; @@ -196,16 +194,18 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { if(err != 0){ std::ifstream _log(_flog); std::string log(std::istreambuf_iterator(_log), {}); + unlink(_fsrc); + unlink(_flog); throw std::runtime_error("Internal Triton PTX codegen error: \n" + log); } CUmodule ret; std::ifstream _cubin(_fbin, std::ios::binary ); std::string cubin(std::istreambuf_iterator(_cubin), {}); _cubin.close(); - dispatch::cuModuleLoadData(&ret, cubin.c_str()); unlink(_fsrc); 
unlink(_flog); unlink(_fbin); + dispatch::cuModuleLoadData(&ret, cubin.c_str()); return cubin; } From a425f24d54fbac756879d97ddc75bebb88fd66be Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 21 Dec 2021 21:29:47 -0800 Subject: [PATCH 031/215] [FRONTEND] Better cache hook (#400) Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary. --- python/src/triton.cc | 26 ++++++++++++++++---------- python/test/unit/runtime/test_cache.py | 7 ++++--- python/triton/code_gen.py | 12 +++++++++++- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index ce56d9c26..b44ffbc27 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -117,7 +117,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f params.reserve(8*len); // 8 max bytes by argument char* params_ptr = ¶ms[0]; cache_key = func_key; + cache_key += "-" + std::to_string(num_warps); + cache_key += "-" + std::to_string(num_stages); + cache_key += "-"; for(int i = 0; i < len; i++){ + cache_key += "_"; py::int_ py_i = py::int_(i); bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end(); py::object arg = args[i]; @@ -127,19 +131,20 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f if(PyLong_Check(arg_ptr)){ int overflow; long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); + // values equal to 1 are specialized if(specialize && (value == 1)){ - cache_key += '1'; + cache_key += "1"; continue; } // long and int have different kernels if(!overflow & (std::abs(value) <= 0xffffffff)){ - cache_key += 'I'; + cache_key += "int32"; params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); std::memcpy(params_ptr, &value, 4); params_ptr += 4; } else{ - cache_key += 'L'; + cache_key += "int64"; params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); if(overflow){ unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr); @@ -150,15 +155,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f } if(!specialize) continue; - // values equal to 1 are specialized - cache_key += 'x'; // values divisible by small powers of 2 are specialized + cache_key += "[multipleof("; cache_key += pow2_divisor(value); + cache_key += ")]"; continue; } // argument is `float` if(PyFloat_Check(arg_ptr)){ - cache_key += "f"; + cache_key += "float32"; float value = PyFloat_AsDouble(arg_ptr); params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); std::memcpy(params_ptr, &value, 4); @@ -167,7 +172,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f } // argument is `bool` if(PyBool_Check(arg_ptr)){ - cache_key += "B"; + cache_key += "bool"; bool value = arg_ptr == Py_True ? true : false; std::memcpy(params_ptr, &value, 1); params_ptr += 1; @@ -176,7 +181,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f // argument is tensor if(py::hasattr(arg, "data_ptr")){ py::object data_ptr = arg.attr("data_ptr")(); - cache_key += "P"; long value = data_ptr.cast(); params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); std::memcpy(params_ptr, &value, 8); @@ -186,6 +190,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' 
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; cache_key += std::string(start, len); + cache_key += "*"; + cache_key += "[multipleof("; + cache_key += pow2_divisor(value); + cache_key += ")]"; continue; } // argument is `constexpr` @@ -208,8 +216,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; throw std::runtime_error(err_msg); } - cache_key += std::to_string(num_warps); - cache_key += std::to_string(num_stages); params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]); } diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a1c994241..3ad387f09 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -64,20 +64,21 @@ def reset_tmp_dir(): def test_reuse(): counter = 0 - def inc_counter(key, binary): + def inc_counter(key, binary, repr): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') for i in range(10): - kernel[(1,)](x, 43, BLOCK=1024) + kernel[(1,)](x, 1, BLOCK=1024) assert counter == 1 + @pytest.mark.parametrize('mode', ['enable', 'disable']) def test_specialize(mode): counter = 0 - def inc_counter(key, binary): + def inc_counter(key, binary, repr): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 2f6ddf3c1..cacc39675 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -700,7 +700,17 @@ class Kernel: pickle.dump({"binary": binary, "key": key}, f) os.rename(bin_cache_path + ".tmp", bin_cache_path) if JITFunction.cache_hook is not None: - JITFunction.cache_hook(key=key, binary=binary) + name = self.fn.fn.__name__ + info = key.split('-')[-3:] + num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] + # make signature human-readable + arg_reprs = [] + for arg_name, arg_sig in zip(self.fn.arg_names, sig): + arg_reprs.append(f'{arg_name}: {arg_sig}') + # assemble the repr + arg_reprs = ", ".join(arg_reprs) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + JITFunction.cache_hook(key=key, binary=binary, repr=repr) self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) From d8fce83e7a7d54ba66b55415b828d40e4f322169 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 21 Dec 2021 22:14:06 -0800 Subject: [PATCH 032/215] [FRONTEND] Remade exception picklable --- python/triton/code_gen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index cacc39675..8393f2b87 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -544,6 +544,8 @@ class CompilationError(Exception): self.message += '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' super().__init__(self.message) + # this is necessary to make CompilationError picklable + self.args = (src, node) class OutOfResources(Exception): From 985798f10127fe5c1e918c09e1889de51d65abfe Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Thu, 23 Dec 2021 17:01:17 -0800 Subject: [PATCH 033/215] add missing bfloat16 repr and improve assertions (#403) - `BF16TyID` was missing a repr implementation. - Throw a better exception on impossible casts. - Add a few assertions. Tested with a debug build. - Add `pointer_dtype.__str__` to aid kernel debugging. 
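Before the diff for this change, one note on the cache-hook patch (#400) above: `JITFunction.cache_hook` is now invoked with a third, human-readable `repr` keyword in addition to `key` and `binary`. A minimal sketch of a hook that uses it follows; the hook body, the kernel, and the printed text are illustrative, and the import path is an assumption based on `JITFunction` living in `python/triton/code_gen.py`.

    import torch
    import triton
    import triton.language as tl
    # Import path assumed from the file layout above (JITFunction is defined in python/triton/code_gen.py).
    from triton.code_gen import JITFunction

    # Hypothetical hook: log every fresh compilation. Per the code_gen.py hunk of #400, `repr` reads
    # roughly like "copy_kernel[num_warps=..., num_stages=...](X: float32*[multipleof(16)], ...)".
    def log_compilation(key, binary, repr):
        print(f"compiled {repr}")

    JITFunction.cache_hook = log_compilation

    @triton.jit
    def copy_kernel(X, Z, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        tl.store(Z + off, tl.load(X + off))

    x = torch.randn(1024, device='cuda')
    z = torch.empty_like(x)
    copy_kernel[(1,)](x, z, SIZE=1024)  # first call (with a cold cache) compiles and fires the hook
    copy_kernel[(1,)](x, z, SIZE=1024)  # cached binary: the hook is not called again (cf. test_reuse)

As `test_reuse` above shows, the hook fires once per compiled binary, not once per launch.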
--- include/triton/ir/type.h | 4 ++-- lib/ir/dispatch.cc | 4 ++-- lib/ir/instructions.cc | 3 +-- python/src/functions.h | 2 +- python/src/triton.cc | 1 + python/test/unit/language/test_core.py | 5 ----- python/triton/language/core.py | 5 +++++ 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 7fb14877a..c9c07c4f1 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -131,6 +131,7 @@ public: case FP16TyID: return "f16"; case FP32TyID: return "f32"; case FP64TyID: return "f64"; + case BF16TyID: return "bf16"; case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; @@ -141,8 +142,7 @@ public: case BlockTyID: return tile_repr(); default: break; } - assert(false); - return ""; + throw std::logic_error("unknown type id '" + std::to_string(id_) + "'"); }; private: diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index c616b2fd4..c4e8ccafb 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -55,7 +55,7 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) } } if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) - throw_unreachable("augment_types"); + throw_unreachable("computation_type"); // 4 ) both operands are integer and undergo // integer promotion return integer_promote(a_ty, b_ty); @@ -493,7 +493,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build other = builder->create_splat(other, src_ty->get_block_shapes()); return builder->create_icmpNE(input, other); } - return throw_unreachable("cast"); + return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); } //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 32e7674c6..00d801616 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -232,6 +232,7 @@ icmp_inst::icmp_inst(type *ty, cmp_pred_t pred, icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ assert(is_int_predicate(pred)); + assert(lhs->get_type() == rhs->get_type()); type *res_ty = make_cmp_result_type(lhs->get_type()); return new icmp_inst(res_ty, pred, lhs, rhs, name, next); } @@ -920,7 +921,5 @@ const constant_int* make_range::get_last() const { return last_; } - - } } diff --git a/python/src/functions.h b/python/src/functions.h index 0f5a5c42f..19f7e7eb9 100644 --- a/python/src/functions.h +++ b/python/src/functions.h @@ -105,7 +105,7 @@ ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) { other = builder->create_splat(other, src_ty->get_block_shapes()); return builder->create_icmpNE(input, other); } - throw_not_implemented("cast"); + throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); } /*---------------------------------------------- diff --git a/python/src/triton.cc b/python/src/triton.cc index b44ffbc27..cec6fba94 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -187,6 +187,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f params_ptr += 8; py::object dtype = arg.attr("dtype"); py::object repr = py::repr(dtype); + assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6)); const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' 
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; cache_key += std::string(start, len); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 785ca49ac..aa0e7430a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -453,11 +453,6 @@ def test_reduce2d(dtype, shape, axis, device='cuda'): # --------------- # test permute # --------------- - -# --------------- -# test permute -# --------------- - @pytest.mark.parametrize("dtype, shape, perm", [(dtype, shape, perm) \ for dtype in ['float32']\ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 1c28cdef7..e939319aa 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -86,11 +86,16 @@ class dtype: class pointer_dtype: def __init__(self, element_ty): + if not isinstance(element_ty, dtype): + raise TypeError('element_ty is a {type(element_ty).__name__}.') self.element_ty = element_ty def handle(self, builder): return ir.type.make_ptr(self.element_ty.handle(builder), 1) + def __str__(self): + return f'pointer<{self.element_ty}>' + # scalar types int1 = dtype(ir.type.get_int1) int8 = dtype(ir.type.get_int8) From 3edc2633e9876005da2423f1463fe2e215f52f04 Mon Sep 17 00:00:00 2001 From: Noah Ziems Date: Wed, 29 Dec 2021 18:09:34 -0500 Subject: [PATCH 034/215] [TUTORIALS] Fix 01-vector-add.py typo (#406) --- python/tutorials/01-vector-add.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 4446cf6e9..b25698a4e 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -37,7 +37,7 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) # Create a mask to guard memory operations against out-of-bounds accesses mask = offsets < n_elements - # Load x and y from DRAM, masking out any extar elements in case the input is not a + # Load x and y from DRAM, masking out any extra elements in case the input is not a # multiple of the block size x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) From 03f1256f603cf3626e005a181636fb0a44fe8761 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 30 Dec 2021 22:33:24 -0800 Subject: [PATCH 035/215] [FRONTEND] Added `volatile` flag for load (#407) --- include/triton/ir/builder.h | 4 ++-- include/triton/ir/dispatch.h | 2 +- include/triton/ir/instructions.h | 17 ++++++++++++----- lib/codegen/selection/generator.cc | 6 +++++- lib/codegen/transform/peephole.cc | 3 ++- lib/codegen/transform/pipeline.cc | 12 ++++++------ lib/ir/builder.cc | 8 ++++---- lib/ir/dispatch.cc | 7 ++++--- lib/ir/instructions.cc | 22 +++++++++++----------- python/src/triton.cc | 3 ++- python/triton/language/core.py | 4 ++-- 11 files changed, 51 insertions(+), 37 deletions(-) diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index a80bc471f..357fffc6a 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -130,9 +130,9 @@ public: value *create_xor(value *lhs, value *rhs); value *create_or(value *lhs, value *rhs); // Input/Output - value *create_load(value *arg, load_inst::CACHE_MODIFIER cache); + value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile); value *create_store(value *ptr, value *val); - value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); + value *create_masked_load(value *arg, value *mask, value 
*false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile); value *create_masked_store(value *ptr, value *val, value *mask); // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index b6cc7dcc1..c90480f1e 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -68,7 +68,7 @@ struct dispatch{ static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder); + static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder); static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 699d22257..7c147f634 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -402,8 +402,11 @@ public: }; CACHE_MODIFIER get_cache_modifier() const { return cache_; } + bool get_is_volatile() const { return is_volatile_; } + protected: load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, + bool is_volatile, const std::string &name = "", instruction *next = nullptr); std::string get_cache_modifier_repr() const { if (cache_ == CA) return ".ca"; @@ -412,20 +415,24 @@ protected: } CACHE_MODIFIER cache_; + std::string get_volatile_repr() { + return is_volatile_ ? 
".volatile" : ""; + } + bool is_volatile_; + private: static type *get_pointee_type(type *ty); - }; // unmasked load class unmasked_load_inst: public load_inst { private: std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } - unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next); + unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next); public: static unmasked_load_inst* create(value *ptr, - CACHE_MODIFIER cache, + CACHE_MODIFIER cache, bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(unmasked_load_inst) @@ -436,7 +443,7 @@ public: class masked_load_inst: public load_inst { private: std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } - masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, + masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next); public: @@ -445,7 +452,7 @@ public: value *get_false_value_operand() { return get_operand(2); } // factory method static masked_load_inst* create(value *ptr, value *mask, value *false_value, - CACHE_MODIFIER cache, + CACHE_MODIFIER cache, bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_load_inst) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a838bfa5a..a6148b2d1 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -629,7 +629,11 @@ void generator::visit_load_inst(ir::load_inst* x){ // ----- std::ostringstream asm_oss; asm_oss << "@$" << n_words; // predicate - asm_oss << " ld.global"; + asm_oss << " ld"; +// std::cout << x->get_is_volatile() << std::endl; + if(x->get_is_volatile()) + asm_oss << ".volatile"; + asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; if(n_words > 1) diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index f7ebdad80..864fea85c 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -207,7 +207,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(), if_value->get_mask_operand(), select->get_else_value_op(), - if_value->get_cache_modifier()); + if_value->get_cache_modifier(), + if_value->get_is_volatile()); select->replace_all_uses_with(new_load); return true; } diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index bc249841b..eb3fe6164 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) { false_value = remat_false_value; } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier()); + first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile()); for (int stage = 1; stage < num_stages-1; ++stage) { // mask is the loop condition of the previous iteration @@ -193,7 
+193,7 @@ void pipeline::run(ir::module &mod) { first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); false_value = remat_false_value; } - first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier()); + first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile()); } // create new phis for induction variables @@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); // phi node @@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) { } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier()); + ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); // pre-fetch next iteration builder.set_insert_point(block->get_inst_list().back()); ir::value* next_ptr = ptr->get_value_for_block(block); @@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); // phi node builder.set_insert_point(block->get_first_non_phi()); ir::phi_node* new_load = builder.create_phi(ty, 2); @@ -327,4 +327,4 @@ void pipeline::run(ir::module &mod) { } } -} \ No newline at end of file +} diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index cc1d354ee..feac3c6b6 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) // load/store instructions //===----------------------------------------------------------------------===// -value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){ - return insert(unmasked_load_inst::create(ptr, cache)); +value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){ + return insert(unmasked_load_inst::create(ptr, cache, is_volatile)); } value *builder::create_store(value *ptr, value *val){ return insert(unmasked_store_inst::create(ptr, val)); } -value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){ - return insert(masked_load_inst::create(ptr, mask, false_value, cache)); +value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){ + return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile)); } value *builder::create_masked_store(value *ptr, value *val, value *mask){ diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index c4e8ccafb..2b8ae0259 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -500,7 +500,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build // Memory Operators 
//===----------------------------------------------------------------------===// -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) { +ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) { if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); if(ptr->get_type()->is_block_ty()){ @@ -520,6 +520,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); ptr = dispatch::cast(ptr, ptr_ty, builder); } + // cache modifier load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default if (!cache_modifier.empty()) { if (cache_modifier == ".ca") @@ -530,7 +531,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); } if (!mask && !other) - return builder->create_load(ptr, cache); + return builder->create_load(ptr, cache, is_volatile); if (!mask) throw std::runtime_error("`other` cannot be provided without `mask`"); auto shape = ptr->get_type()->get_block_shapes(); @@ -539,7 +540,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con if(ptr->get_type()->is_block_ty()) other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); } - return builder->create_masked_load(ptr, mask, other, cache); + return builder->create_masked_load(ptr, mask, other, cache, is_volatile); } ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 00d801616..0206b7e77 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n { } // load_inst -load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) - : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache) +load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) + : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile) { } // load @@ -448,35 +448,35 @@ type *load_inst::get_pointee_type(type *ty) { } // unmasked_load -unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) - : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) { +unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) + : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) { set_operand(0, ptr); } -unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) { - return new unmasked_load_inst(ptr, cache, name, next); +unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) { + return new unmasked_load_inst(ptr, cache, is_volatile, name, 
next); } // masked load -masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, +masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) { + : load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, + load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) { - return new masked_load_inst(ptr, mask, false_value, cache, name, next); + return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next); } // masked load async masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) { + : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); diff --git a/python/src/triton.cc b/python/src/triton.cc index cec6fba94..783b0406a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -600,7 +600,8 @@ void init_triton_ir(py::module &&m) { py::class_(m, "constant_int") .def_property_readonly("value", &ir::constant_int::get_value) - .def("__int__", [](ir::constant_int *self) { return self->get_value(); }); + .def("__int__", [](ir::constant_int *self) { return self->get_value(); }) + .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); py::class_(m, "constant_float") .def_property_readonly("value", &ir::constant_fp::get_value); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e939319aa..d5d3313e5 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -519,7 +519,7 @@ def dot(input, other, _builder=None): @builtin -def load(pointer, mask=None, other=None, cache_modifier="", _builder=None): +def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None): """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. @@ -536,7 +536,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", _builder=None): :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - return frontend.load(pointer, mask, other, cache_modifier, _builder) + return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder) @builtin From d8db0308cb30eeff04dde487db0e384b3cf4d9a0 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Tue, 4 Jan 2022 13:07:29 -0800 Subject: [PATCH 036/215] [TEST] use numpy for reference results in test_core.py (#409) Since numpy supports unsigned integers, and pytorch doesn't, this will make it easier to test unsigned integer support. This adds an explicit requirement for numpy in tests, but we already required scipy, so it was already an implicit dependency. 
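Before turning to the numpy migration below, a brief illustration of the `volatile` flag that patch #407 above adds to `tl.load`. This is a hypothetical kernel, not one from the repository; the point is only that `volatile=True` keeps the load from being cached or reused when the underlying memory may be written concurrently by the host or another kernel.

    import triton
    import triton.language as tl

    # Illustrative only: mark the load volatile so a flag that another agent may be
    # writing is actually re-read rather than served from a cached value.
    @triton.jit
    def read_flags(FlagPtr, Out, BLOCK: tl.constexpr):
        off = tl.arange(0, BLOCK)
        flags = tl.load(FlagPtr + off, volatile=True)  # lowered to ld.volatile.global (see generator.cc hunk)
        tl.store(Out + off, flags)

The flag threads from `tl.load` through `dispatch::load` into the `unmasked_load_inst`/`masked_load_inst` constructors, and the generator emits `ld.volatile.global` for it, as the C++ hunks above show.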
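As for #409 itself, the pattern the rewritten tests follow is: generate inputs with numpy, compute the reference result with numpy, move copies to the GPU with torch, and compare with `np.testing.assert_allclose`. A minimal sketch of that pattern, using an illustrative kernel rather than one from the suite:

    import numpy as np
    import torch
    import triton
    import triton.language as tl

    # Illustrative kernel (not from the test suite).
    @triton.jit
    def square_kernel(X, Z, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        tl.store(Z + off, x * x)

    # numpy owns input generation and the reference result ...
    rs = np.random.RandomState(17)
    x = rs.normal(0, 1, 128).astype(np.float32)
    z_ref = x * x
    # ... and torch only moves data to and from the GPU.
    x_tri = torch.tensor(x, device='cuda')
    z_tri = torch.empty_like(x_tri)
    square_kernel[(1,)](x_tri, z_tri, SIZE=128)
    np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01)

The `numpy_random`, `to_triton`, and `to_numpy` helpers in the hunk below package exactly these steps for the whole test file.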
--- python/requirements-test.txt | 1 + python/test/unit/language/test_core.py | 415 ++++++++++++++----------- python/triton/testing.py | 29 +- 3 files changed, 241 insertions(+), 204 deletions(-) diff --git a/python/requirements-test.txt b/python/requirements-test.txt index 48b6d3be3..84893a889 100644 --- a/python/requirements-test.txt +++ b/python/requirements-test.txt @@ -1,2 +1,3 @@ +numpy pytest scipy >= 1.7.1 diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index aa0e7430a..fe33c9c6a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,31 +1,59 @@ +import copy +import itertools +import re +from typing import Optional + +import numpy as np +import pytest import torch +from numpy.random import RandomState + import triton import triton.language as tl -import copy -import pytest -import ast -import itertools - -torch.manual_seed(0) - -# convert from string to torch.dtype -# Necessary because doesn't print torch.dtype properly -cvt = { - 'bool': torch.bool, - 'int8': torch.int8, - 'int16': torch.int16, - 'int32': torch.int32, - 'int64': torch.int64, - 'bfloat16': torch.bfloat16, - 'float16': torch.float16, - 'float32': torch.float32, - 'float64': torch.float64, -} int_dtypes = ['int8', 'int16', 'int32', 'int64'] float_dtypes = ['float16', 'float32', 'float64'] dtypes = int_dtypes + float_dtypes +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + dtype = getattr(np, dtype_str) + if dtype_str in int_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype) + x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype) + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor: + # For now, this always converts to a torch tensor, but when we add unsigned + # integers, it will also support TensorWrapper, since torch doesn't have + # unsigned support. 
+ return torch.tensor(x, device=device) + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + def patch_kernel(template, to_replace): kernel = copy.deepcopy(template) @@ -34,19 +62,18 @@ def patch_kernel(template, to_replace): return kernel -@pytest.mark.parametrize("dtype_x", [ - (dtype_x) for dtype_x in dtypes -]) +@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 @triton.jit def kernel(X, SIZE: tl.constexpr): pass - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device) kernel[(1, )](x, SIZE=SIZE, num_warps=4) + # generic test functions -def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -58,18 +85,36 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'): kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device) - if 'log' in expr: x = torch.abs(x) + 0.01 + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 # reference result - z_ref = eval(expr if torch_expr is None else torch_expr) + z_ref = eval(expr if numpy_expr is None else numpy_expr) # triton result - z_tri = torch.empty_like(z_ref) - kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + x_tri = to_triton(x, device=device) + z_tri = to_triton(np.empty_like(z_ref), device=device) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) # compare - triton.testing.assert_almost_equal(z_ref, z_tri) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y='real', device='cuda'): +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations. 
+ """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -82,27 +127,24 @@ def _test_binary(dtype_x, dtype_y, expr, torch_expr=None, mode_x='real', mode_y= kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs - x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device, seed=17) - y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device, seed=144) - if mode_x == 'nan': x[:] = float('nan') - if mode_y == 'nan': y[:] = float('nan') + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') # reference result - z_ref = eval(expr if torch_expr is None else torch_expr) + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if dtype_z is not None: + z_ref = z_ref.astype(dtype_z) # triton result - z_tri = torch.empty(SIZE, dtype=z_ref.dtype, device=device) - kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) - # compare - triton.testing.assert_almost_equal(z_ref, z_tri, err_msg=expr) - - -def _fake_fmod(x, y): - """ - Triton % (for both integers and floats) has the same semantics as torch - fmod, but torch fmod doesn't work on integers until torch 1.8. - `_fake_fmod` gives the same semantics but works on all versions of torch. - """ - z = torch.remainder(x, y) - return torch.where((torch.sign(x) != torch.sign(y)) & (z != 0), z - y, z) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: @@ -130,36 +172,38 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes: - # LLVM has 'torch.fmod', not 'torch.remainder' semantics on integer remainders. - torch_expr = '_fake_fmod(x, y)' + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): # Triton promotes 16-bit floating-point / and % to 32-bit because there # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. 
- torch_expr = f'x.to(torch.float32) {op} y.to(torch.float32)' + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' else: - torch_expr = None + numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): - with pytest.raises(AssertionError, match='Arrays are not almost equal'): - _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device) + with pytest.raises(AssertionError, match='Not equal to tolerance'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) else: - _test_binary(dtype_x, dtype_y, expr, torch_expr=torch_expr, device=device) - + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) # --------------- # test bitwise ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, expr", [ - (dtype_x, dtype_y, f' x {op} y') \ - for op in ['&', '|', '^'] \ - for dtype_x in dtypes \ - for dtype_y in dtypes +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + for dtype_y in dtypes ]) -def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'): +def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): + expr = f'x {op} y' if 'float' in dtype_x + dtype_y: - with pytest.raises(RuntimeError): - _test_binary(dtype_x, dtype_y, expr, device=device) + with pytest.raises(triton.code_gen.CompilationError) as exc_info: + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) + # The CompilationError must have been caused by a C++ exception with this text. + assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: _test_binary(dtype_x, dtype_y, expr, device=device) @@ -168,23 +212,24 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'): # test compare ops # --------------- ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \ +@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \ # real [ - (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \ + (dtype_x, dtype_y, op, 'real', 'real') \ for op in ops \ for dtype_x in dtypes \ for dtype_y in dtypes ] + \ # NaNs -[('float32', 'float32', f' x {op} y', mode_x, mode_y) \ +[('float32', 'float32', op, mode_x, mode_y) \ for op in ops - for mode_x, mode_y in [('nan' , 'real'), - ('real', 'nan'), + for mode_x, mode_y in [('nan' , 'real'), + ('real', 'nan'), ('nan' , 'nan')] ]) -def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'): +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): + expr = f'x {op} y' _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device) @@ -192,9 +237,9 @@ def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'): # test unary ops # --------------- @pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, f' -x') for dtype_x in float_dtypes + (dtype_x, ' -x') for dtype_x in dtypes ] + [\ - (dtype_x, f' ~x') for dtype_x in int_dtypes + (dtype_x, ' ~x') for dtype_x in int_dtypes ]) def test_unary_op(dtype_x, expr, device='cuda'): _test_unary(dtype_x, expr, device=device) @@ -210,7 +255,7 @@ def test_unary_op(dtype_x, expr, device='cuda'): 'exp', 'log', 'cos', 'sin' ]) def test_math_op(expr, device='cuda'): - _test_unary('float32', f'tl.{expr}(x)', f'torch.{expr}(x) ', device=device) + _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) # ---------------- @@ -229,12 +274,11 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" 
-@pytest.mark.parametrize("expr", [f'x[{s}]' for s in - ['None, :', ':, None',\ - 'None, :, :', ':, :, None']\ +@pytest.mark.parametrize("expr, dtype_str", [ + (f'x[{s}]', 'int32') + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] ]) -def test_index1d(expr, device='cuda'): - dtype = torch.int32 +def test_index1d(expr, dtype_str, device='cuda'): rank_x = expr.count(':') rank_y = expr.count(',') + 1 shape_x = [32 for _ in range(rank_x)] @@ -257,14 +301,15 @@ def test_index1d(expr, device='cuda'): kernel = patch_kernel(kernel, to_replace) # torch result - x = triton.testing.random(shape_x, dtype=dtype, device=device) - y = torch.zeros(shape_z, dtype=dtype, device=device) + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) z_ref = eval(expr) + y # triton result - z_tri = torch.empty_like(z_ref) - kernel[(1, )](z_tri, x, num_warps=1, SIZE=shape_x[0]) + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x) + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) # compare - triton.testing.assert_almost_equal(z_ref, z_tri) + assert (z_ref == to_numpy(z_tri)).all() # --------------- @@ -316,14 +361,15 @@ def test_tuples(): # --------------- # test atomics # --------------- -@pytest.mark.parametrize("op, dtype_x, mode", itertools.chain.from_iterable([ - [('add', 'int32', mode), ('add', 'float16', mode), ('add', 'float32', mode), \ - ('max', 'int32', mode), ('max', 'float32', mode),\ - ('min', 'int32', mode), ('min', 'float32', mode),\ +@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([ + [ + ('add', 'float16', mode), + ('add', 'int32', mode), ('add', 'float32', mode), + ('max', 'int32', mode), ('max', 'float32', mode), + ('min', 'int32', mode), ('min', 'float32', mode), ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) -def test_atomic_rmw(op, dtype_x, mode, device='cuda'): - dtype_x = cvt[dtype_x] +def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): n_programs = 5 # triton kernel @@ -334,52 +380,59 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): old = GENERATE_TEST_HERE kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'}) - torch_op = {'add': torch.sum, 'max': torch.max, 'min': torch.min}[op] - max_neutral = float('-inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).min - min_neutral = float('inf') if dtype_x.is_floating_point else torch.iinfo(dtype_x).max + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] # triton result - x_tri = triton.testing.random((n_programs, ), dtype=dtype_x, device=device) + rs = RandomState(17) + x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs) if mode == 'all_neg': - x_tri = -torch.abs(x_tri) + x = -np.abs(x) if mode == 'all_pos': - x_tri = torch.abs(x_tri) + x = np.abs(x) if mode == 'min_neg': - idx = torch.randint(n_programs, size=(1, )).item() - x_tri[idx] = -torch.max(torch.abs(x_tri)) - 1 + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 if mode == 'max_pos': - idx = torch.randint(n_programs, size=(1, )).item() - x_tri[idx] = torch.max(torch.abs(x_tri)) + 1 + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = 
to_triton(x, device=device) - z_tri = torch.empty([], dtype=dtype_x, device=device) - z_tri.fill_(neutral) + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) kernel[(n_programs, )](x_tri, z_tri) # torch result - z_ref = torch_op(x_tri).to(dtype_x) + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) # compare exact = op not in ['add'] if exact: - assert z_ref.item() == z_tri.item() + assert z_ref.item() == to_numpy(z_tri).item() else: - triton.testing.assert_almost_equal(z_ref, z_tri) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001) # --------------- # test cast # --------------- @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ - (dtype_x, dtype_z, False) \ - for dtype_x in dtypes\ + (dtype_x, dtype_z, False) + for dtype_x in dtypes for dtype_z in dtypes -] + [ +] + [ ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), - ('float32', 'int32', True) -]) + ('float32', 'int32', True), +] +) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - x0 = 43 if dtype_x.startswith('int') else 43.5 - x = torch.tensor([x0], dtype=cvt[dtype_x], device=device) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + x0 = 43 if dtype_x in int_dtypes else 43.5 + if dtype_x.startswith('bfloat'): + x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device) + else: + x = np.array([x0], dtype=getattr(np, dtype_x)) + x_tri = to_triton(x) # triton kernel @triton.jit @@ -389,26 +442,31 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): tl.store(Z, z) # triton result - z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device) - kernel[(1, )](x, z_tri, BITCAST=bitcast) - # torch result - if bitcast: - import numpy as np - z_ref = x.detach().cpu().numpy().view(getattr(np, dtype_z)) - z_ref = torch.from_numpy(z_ref).to(device) + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device) else: - z_ref = x.to(z_tri.dtype) - assert z_tri == z_ref + z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + assert z_tri == z_ref + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z)) + else: + z_ref = x.astype(getattr(np, dtype_z)) + assert to_numpy(z_tri) == z_ref # --------------- # test reduce # --------------- -@pytest.mark.parametrize("dtype, shape", +@pytest.mark.parametrize("dtype_str, shape", [(dtype, shape) \ for dtype in dtypes\ for shape in [128, 512]]) -def test_reduce1d(dtype, shape, device='cuda'): - dtype = cvt[dtype] +def test_reduce1d(dtype_str, shape, device='cuda'): # triton kernel @triton.jit @@ -416,22 +474,22 @@ def test_reduce1d(dtype, shape, device='cuda'): x = tl.load(X + tl.arange(0, BLOCK)) tl.store(Z, tl.sum(x, axis=0)) - x = triton.testing.random((shape,), dtype=dtype, device=device) + rs = RandomState(17) + x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) + # numpy result + z_ref = np.sum(x).astype(getattr(np, dtype_str)) # triton result - z_tri = triton.testing.random((1,), dtype=dtype, device=device) - kernel[(1,)](x, z_tri, BLOCK=shape) - # torch result - z_ref = torch.sum(x).to(dtype) + x_tri = to_triton(x, device=device) + z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device) + kernel[(1,)](x_tri, z_tri, BLOCK=shape) # compare - 
triton.testing.assert_almost_equal(z_tri, z_ref) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -@pytest.mark.parametrize("dtype, shape, axis", - [(dtype, shape, 1) \ - for dtype in ['float32']\ - for shape in [(1, 1024)]]) -def test_reduce2d(dtype, shape, axis, device='cuda'): - dtype = cvt[dtype] +@pytest.mark.parametrize("dtype_str, shape, axis", [ + ('float32', (1, 1024), 1) +]) +def test_reduce2d(dtype_str, shape, axis, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): @@ -441,29 +499,30 @@ def test_reduce2d(dtype, shape, axis, device='cuda'): z = tl.sum(x, axis=AXIS) tl.store(Z + range_m, z) # input - x = triton.testing.random(shape, dtype=dtype, device=device) + x = numpy_random(shape, dtype_str=dtype_str) # triton result - z_tri = torch.empty((shape[0],), dtype=dtype, device=device) - kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) - # torch result - z_ref = torch.sum(x, axis=axis).to(dtype) + x_tri = to_triton(x) + z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device) + kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) + # numpy reference result + z_ref = np.sum(x, axis=axis).astype(x.dtype) # compare - triton.testing.assert_almost_equal(z_tri, z_ref) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # --------------- # test permute # --------------- -@pytest.mark.parametrize("dtype, shape, perm", +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) \ for dtype in ['float32']\ for shape in [(128, 128)]\ - for perm in [(1, 0)]]) -def test_permute(dtype, shape, perm, device='cuda'): - dtype = cvt[dtype] + for perm in [(1, 0)]]) +def test_permute(dtype_str, shape, perm, device='cuda'): + # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, + def kernel(X, stride_xm, stride_xn, + Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) @@ -471,14 +530,15 @@ def test_permute(dtype, shape, perm, device='cuda'): Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn tl.store(Zs, tl.load(Xs)) # input - x = triton.testing.random(shape, dtype=dtype, device=device) + x = numpy_random(shape, dtype_str=dtype_str) # triton result - z_tri = torch.empty_like(x) - pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), - z_tri, z_tri.stride(1), z_tri.stride(0), - BLOCK_M=shape[0], BLOCK_N=shape[1]) + z_tri = to_triton(np.empty_like(x), device=device) + x_tri = to_triton(x, device=device) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1]) # torch result - z_ref = x.permute(*perm).contiguous() + z_ref = x.transpose(*perm) # compare triton.testing.assert_almost_equal(z_tri, z_ref) # parse ptx to make sure ld/st are vectorized @@ -491,13 +551,12 @@ def test_permute(dtype, shape, perm, device='cuda'): # --------------- @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) -def test_dot(epilogue, dtype=torch.float32, device='cuda'): - torch.manual_seed(0) +def test_dot(epilogue, device='cuda'): # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xk, + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, - Z, stride_zm, stride_zn, + Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr): off_m = tl.arange(0, BLOCK_M) @@ -513,36 +572,38 @@ def test_dot(epilogue, dtype=torch.float32, device='cuda'): ZRs = Z + off_m * stride_zm z += tl.load(ZRs)[:, None] if ADD_COLS: - ZCs = Z + off_n * stride_zn + ZCs = Z + off_n * stride_zn z += tl.load(ZCs)[None, :] tl.store(Zs, z) # input M, N, K = 64, 64, 32 - x = triton.testing.random((M, K), dtype=dtype, device=device) - y = triton.testing.random((K, N), dtype=dtype, device=device) + rs = RandomState(17) + x = numpy_random((M, K), dtype_str='float32', rs=rs) + y = numpy_random((K, N), dtype_str='float32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) # triton result - z = triton.testing.random((M, N), dtype=dtype, device=device) - z_tri = z.clone() + z = numpy_random((M, N), dtype_str='float32', rs=rs) + z_tri = to_triton(z, device=device) if epilogue == 'trans': z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) - pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), - y, y.stride(0), y.stride(1), + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + y_tri, y_tri.stride(0), y_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX = epilogue=='add-matrix', ADD_ROWS = epilogue=='add-rows', ADD_COLS = epilogue=='add-cols') # torch result - z_ref = torch.matmul(x.float(), y.float()) + z_ref = np.matmul(x, y) if epilogue == 'add-matrix': z_ref += z if epilogue == 'add-rows': z_ref += z[:,0][:, None] if epilogue == 'add-cols': z_ref += z[0,:][None, :] - z_ref = z_ref.to(torch.float16) # compare - triton.testing.assert_almost_equal(z_tri, z_ref) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # make sure ld/st are vectorized ptx = pgm.asm['ptx'] assert 'ld.global.v4' in ptx @@ -558,7 +619,7 @@ def test_dot_without_load(): c = tl.dot(a, b) pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :] tl.store(pout, c) - + out = torch.ones((32,32), dtype=torch.float32, device="cuda") kernel[(1,)](out) @@ -571,7 +632,7 @@ def test_arange(start, device='cuda'): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @triton.jit - def _kernel(z, BLOCK: tl.constexpr, + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): off = tl.arange(0, BLOCK) val = tl.arange(START, END) @@ -605,8 +666,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'): N_offsets = tl.arange(0, N) K_offsets = tl.arange(0, K) - in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] - in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] # Load inputs. 
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel) @@ -616,7 +677,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'): o = tl.dot(x, w) # Store output - output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel) pgm = _kernel[(1,)](in1, in2, out, @@ -687,7 +748,7 @@ def test_default(): def _kernel(ret0, ret1, value): tl.store(ret0, _impl()) tl.store(ret1, _impl(value)) - + _kernel[(1,)](ret0, ret1, value) assert ret0.item() == 10 assert ret1.item() == value @@ -699,5 +760,5 @@ def test_noop(device='cuda'): @triton.jit def kernel(x): pass - x = triton.testing.random((1,), dtype=torch.int32, device=device) + x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) kernel[(1, )](x) diff --git a/python/triton/testing.py b/python/triton/testing.py index f274e808f..eef7f5be6 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -85,31 +85,6 @@ def allclose(x, y, tol=1e-2): return err <= tol -def assert_allclose(x, y, tol=1e-2): - assert x.dtype == y.dtype - assert allclose(x, y, tol) - - -def random(shape, dtype, device, seed=0): - """ - Override the seed in tests if you're calling this function twice and don't - want the same result for both calls. - """ - torch.manual_seed(seed) - if isinstance(shape, int): - shape = (shape, ) - if dtype == torch.bool: - return torch.randint(0, 2, shape, dtype=dtype, device=device) - if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: - iinfo = torch.iinfo(dtype) - x = torch.randint(iinfo.min, iinfo.max, shape, dtype=dtype, device=device) - x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. - return x - if dtype in [torch.float16, torch.float32, torch.float64]: - return torch.normal(0, 1, shape, dtype=dtype, device=device) - raise RuntimeError(f'Unknown dtype {dtype}') - - def nvsmi(attrs): attrs = ','.join(attrs) cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] @@ -203,7 +178,7 @@ class Benchmark: styles=None, ): """ - Constructor + Constructor :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value. :type x_names: List[str] @@ -344,4 +319,4 @@ def get_max_tensorcore_tflops(backend, device): else: ops_per_sub_core = 512 tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) - return tflops \ No newline at end of file + return tflops From 0ab9d67bade400f005cfc09849434c3a5a9c7bdf Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Wed, 5 Jan 2022 15:27:17 -0800 Subject: [PATCH 037/215] uint8, uint16, uint32, and uint64 in kernels (#413) A forthcoming PR will update the RNG to use these types. Also: - Add tests for the `//`, `<<`, and `>>` operators. - Change `TensorWrapper` to unwrap objects when the resulting object would be simpler. - Clean up `throw_unreachable`, since it was triggering compiler warnings. 
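This is where the numpy migration in #409 pays off: reference values for unsigned kernel types can be computed on the host even though torch has no unsigned dtypes (on the Triton side, the `TensorWrapper` mentioned above is what ends up carrying such data). A numpy-only sketch of the kind of reference an unsigned test needs; the rotate expression is illustrative and not taken from the patch:

    import numpy as np

    # Host-side reference for an unsigned 32-bit rotate-right, something torch cannot
    # express directly because it has no uint32 tensors. The device-side counterpart
    # would use the unsigned kernel types this patch introduces.
    rs = np.random.RandomState(17)
    x = rs.randint(0, np.iinfo(np.uint32).max, 128, dtype=np.uint32)
    s = rs.randint(1, 32, 128, dtype=np.uint32)
    rot_ref = (x >> s) | (x << (np.uint32(32) - s))  # stays uint32 and wraps modulo 2**32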
--- include/triton/ir/builder.h | 6 + include/triton/ir/context_impl.h | 1 + include/triton/ir/type.h | 22 ++- lib/ir/builder.cc | 18 +++ lib/ir/context.cc | 16 +- lib/ir/dispatch.cc | 146 ++++++++++++----- lib/ir/type.cc | 14 ++ python/src/triton.cc | 73 ++++++--- python/test/unit/language/test_core.py | 190 +++++++++++++++++++---- python/test/unit/language/test_random.py | 1 + python/triton/code_gen.py | 40 ++++- python/triton/language/core.py | 27 +++- 12 files changed, 444 insertions(+), 110 deletions(-) diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 357fffc6a..3a4094123 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -40,6 +40,8 @@ public: value *get_int1(bool val); value *get_int32(int32_t val); value *get_int64(int64_t val); + value *get_uint32(uint32_t val); + value *get_uint64(uint64_t val); value *get_float16(float val); value *get_float32(float val); value *get_range(int32_t lo, int32_t hi); @@ -50,6 +52,10 @@ public: type *get_int16_ty(); type *get_int32_ty(); type *get_int64_ty(); + type *get_uint8_ty(); + type *get_uint16_ty(); + type *get_uint32_ty(); + type *get_uint64_ty(); type *get_half_ty(); type *get_float_ty(); type *get_double_ty(); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 7d18a3b4c..e43b5ad57 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -28,6 +28,7 @@ public: type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; // integer types integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; + integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types std::map, pointer_type*> ptr_tys; // Block types diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index c9c07c4f1..c27ce48cf 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -15,6 +15,8 @@ class value; class integer_type; class constant_int; +enum class signedness { SIGNED, UNSIGNED }; + /* Type */ class type { public: @@ -58,6 +60,8 @@ public: // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bitwidth() const; + signedness get_integer_signedness() const; + bool is_integer_signed() const; unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; @@ -80,8 +84,9 @@ public: bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; } - bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() && - get_integer_bitwidth() == bitwidth;} + bool is_integer_ty(unsigned bitwidth, signedness sn) { + return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn; + } bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } @@ -109,6 +114,10 @@ public: static integer_type *get_int32_ty(context &ctx); static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); + static integer_type *get_uint8_ty(context &ctx); + static integer_type *get_uint16_ty(context &ctx); + static integer_type *get_uint32_ty(context &ctx); + static integer_type *get_uint64_ty(context &ctx); // repr std::string tile_repr() const { @@ -135,7 +144,7 @@ public: case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; - case IntegerTyID: return "i" + 
std::to_string(get_integer_bitwidth()); + case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth()); case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; @@ -158,18 +167,21 @@ class integer_type: public type { private: // constructors - integer_type(context &ctx, unsigned bitwidth) - : type(ctx, IntegerTyID), bitwidth_(bitwidth){ } + integer_type(context &ctx, unsigned bitwidth, signedness sn) + : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ } public: // accessors unsigned get_bitwidth() const { return bitwidth_; } + signedness get_signedness() const { return signedness_; } + // factory methods static integer_type* get(context &ctx, unsigned width); private: unsigned bitwidth_; + signedness signedness_; }; class composite_type: public type{ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index feac3c6b6..a8ba68d1c 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -51,9 +51,15 @@ value *builder::get_int1(bool val) value *builder::get_int32(int32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} +value *builder::get_uint32(uint32_t val) +{ return constant_int::get(type::get_uint32_ty(ctx_), val);} + value *builder::get_int64(int64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} +value *builder::get_uint64(uint64_t val) +{ return constant_int::get(type::get_uint64_ty(ctx_), val);} + value *builder::get_float16(float val) { return constant_fp::get(type::get_fp16_ty(ctx_), val); } @@ -84,6 +90,18 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } +type *builder::get_uint8_ty() +{ return type::get_uint8_ty(ctx_); } + +type *builder::get_uint16_ty() +{ return type::get_uint16_ty(ctx_); } + +type *builder::get_uint32_ty() +{ return type::get_uint32_ty(ctx_); } + +type *builder::get_uint64_ty() +{ return type::get_uint64_ty(ctx_); } + type *builder::get_half_ty() { return type::get_fp16_ty(ctx_); } diff --git a/lib/ir/context.cc b/lib/ir/context.cc index 9bd66ec9a..90b109b9b 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx) fp32_ty(ctx, type::FP32TyID), fp64_ty(ctx, type::FP64TyID), // integers - int1_ty(ctx, 1), - int8_ty(ctx, 8), - int16_ty(ctx, 16), - int32_ty(ctx, 32), - int64_ty(ctx, 64), - int128_ty(ctx, 128){ + int1_ty(ctx, 1, signedness::SIGNED), + int8_ty(ctx, 8, signedness::SIGNED), + int16_ty(ctx, 16, signedness::SIGNED), + int32_ty(ctx, 32, signedness::SIGNED), + int64_ty(ctx, 64, signedness::SIGNED), + int128_ty(ctx, 128, signedness::SIGNED), + uint8_ty(ctx, 8, signedness::UNSIGNED), + uint16_ty(ctx, 16, signedness::UNSIGNED), + uint32_ty(ctx, 32, signedness::UNSIGNED), + uint64_ty(ctx, 64, signedness::UNSIGNED){ } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 2b8ae0259..a46c34b4f 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -1,14 +1,12 @@ #include "triton/ir/dispatch.h" -#include -namespace triton{ -namespace ir{ +namespace triton { +namespace ir { -ir::value* throw_unreachable(std::string key) { +[[ noreturn ]] void throw_unreachable(std::string key) { throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. 
" "This is likely a bug on our side."); - return 0; } //===----------------------------------------------------------------------===// @@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) { ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ int a_rank = a_ty->get_integer_bitwidth(); int b_rank = b_ty->get_integer_bitwidth(); - return a_rank > b_rank ? a_ty : b_ty; + auto a_sn = a_ty->get_integer_signedness(); + auto b_sn = b_ty->get_integer_signedness(); + // Rules for signedness taken from "Usual arithmetic conversions" on + // https://en.cppreference.com/w/c/language/conversion. + if (a_sn == b_sn) { + return a_rank > b_rank ? a_ty : b_ty; + } else if (a_sn == signedness::UNSIGNED) { + return a_rank >= b_rank ? a_ty : b_ty; + } else if (b_sn == signedness::UNSIGNED) { + return b_rank >= a_rank ? b_ty : a_ty; + } else { + throw_unreachable("integer_promote"); + } } enum class DivOrMod { NO, YES }; @@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) throw_unreachable("computation_type"); // 4 ) both operands are integer and undergo // integer promotion + if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) { + throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); + } return integer_promote(a_ty, b_ty); } @@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde // int + int else if (input_scalar_ty->is_integer_ty()) return builder->create_add(input, other); - return throw_unreachable("add"); + throw_unreachable("add"); } ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { @@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builde // int + int else if (scalar_ty->is_integer_ty()) return builder->create_sub(input, other); - return throw_unreachable("sub"); + throw_unreachable("sub"); } ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { @@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde // int * int else if (scalar_ty->is_integer_ty()) return builder->create_mul(input, other); - return throw_unreachable("mul"); + throw_unreachable("mul"); } ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { @@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu } // unreachable else - return throw_unreachable("div"); + throw_unreachable("div"); return builder->create_fdiv(input, other); } @@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); input = dispatch::cast(input, ret_ty, builder); other = dispatch::cast(other, ret_ty, builder); - return builder->create_sdiv(input, other); + if (ret_ty->is_integer_signed()) { + return builder->create_sdiv(input, other); + } else { + return builder->create_udiv(input, other); + } } - return throw_unreachable("floordiv"); + throw_unreachable("floordiv"); } ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + 
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); // float % int if (scalar_ty->is_floating_point_ty()) return builder->create_frem(input, other); // int % int - else if (scalar_ty->is_integer_ty()) - return builder->create_srem(input, other); - return throw_unreachable("mod"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) { + throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); + } + if (scalar_ty->is_integer_signed()) { + return builder->create_srem(input, other); + } else { + return builder->create_urem(input, other); + } + } + throw_unreachable("mod"); } @@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) throw_incompatible_types(input_sca_ty, other_sca_ty); - if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth()) - input = dispatch::cast(input, other_sca_ty, builder); - else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth()) - other = dispatch::cast(other, input_sca_ty, builder); + ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty); + if (ret_sca_ty != input_sca_ty) + input = dispatch::cast(input, ret_sca_ty, builder); + if (ret_sca_ty != other_sca_ty) + other = dispatch::cast(other, ret_sca_ty, builder); } ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { @@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOGT(input, other); // int > int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSGT(input, other); - return throw_unreachable("greater_than"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSGT(input, other); + } else { + return builder->create_icmpUGT(input, other); + } + } + throw_unreachable("greater_than"); } ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOGE(input, other); // int >= int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSGE(input, other); - return throw_unreachable("greater_equal"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSGE(input, other); + } else { + return builder->create_icmpUGE(input, other); + } + } + throw_unreachable("greater_equal"); } ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { @@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder * if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOLT(input, other); // int < int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSLT(input, other); - return throw_unreachable("less_than"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSLT(input, other); + } else { + return 
builder->create_icmpULT(input, other); + } + } + throw_unreachable("less_than"); } ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOLE(input, other); // int < int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSLE(input, other); - return throw_unreachable("less_equal"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSLE(input, other); + } else { + return builder->create_icmpULE(input, other); + } + } + throw_unreachable("less_equal"); } ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil // int == int else if (scalar_ty->is_integer_ty()) return builder->create_icmpEQ(input, other); - return throw_unreachable("equal"); + throw_unreachable("equal"); } ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder * // int == int else if (scalar_ty->is_integer_ty()) return builder->create_icmpNE(input, other); - return throw_unreachable("equal"); + throw_unreachable("equal"); } //===----------------------------------------------------------------------===// @@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build return builder->create_fp_ext(input, dst_ty); // Int cast if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && - src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth()) - return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty()); + (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() || + src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) { + bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty(); + return builder->create_int_cast(input, dst_ty, sign_extend); + } // Float -> Int if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ if(dst_sca_ty->is_bool_ty()) @@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build } // int -> Float if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ - if(src_sca_ty->is_bool_ty()) + if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed()) return builder->create_ui_to_fp(input, dst_ty); else return builder->create_si_to_fp(input, dst_ty); @@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build other = builder->create_splat(other, src_ty->get_block_shapes()); return builder->create_icmpNE(input, other); } - return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); + throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); } //===----------------------------------------------------------------------===// @@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, atom_red_typechecking(ptr, val, mask, builder); ir::type* sca_ty = val->get_type()->get_scalar_ty(); // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) - return 
builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); + if(sca_ty->is_integer_ty()) { + if (sca_ty->is_integer_signed()) { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); + } else { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask); + } + } // for float // return atomic_smax(i_ptr, i_val) if val >= 0 // return atomic_umin(i_ptr, i_val) if val < 0 @@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ atom_red_typechecking(ptr, val, mask, builder); ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); + // direct call to atomic_min for integers + if(sca_ty->is_integer_ty()) { + if (sca_ty->is_integer_signed()) { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); + } else { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask); + } + } // for float // return atomic_smin(i_ptr, i_val) if val >= 0 // return atomic_umax(i_ptr, i_val) if val < 0 @@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder return builder->create_reduce(input, FLOAT_OP, axis); else if (scalar_ty->is_integer_ty()) return builder->create_reduce(input, INT_OP, axis); - return throw_unreachable(name); + throw_unreachable(name); } ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { diff --git a/lib/ir/type.cc b/lib/ir/type.cc index ab8acb24b..74066a65a 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const { unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } +signedness type::get_integer_signedness() const +{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); } + +bool type::is_integer_signed() const { + if (id_ != IntegerTyID) { + throw std::logic_error("type is " + repr() + ", not integer"); + } + return ((integer_type*)(this))->get_signedness() == signedness::SIGNED; +} + unsigned type::get_tile_bitwidth() const { return ((block_type*)(this))->get_bitwidth(); } @@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } +integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; } +integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; } +integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; } +integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index 783b0406a..4d7df76ff 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -109,6 +109,24 @@ std::string pow2_divisor(long N){ return "1"; } +// Returns something like "int16", whether dtype is a torch.dtype or +// triton.language.dtype. 
+std::string dtype_cache_key_part(const py::object& dtype) { + if (py::hasattr(dtype, "cache_key_part")) { + // Presumed to be a triton.language.dtype. + return std::string(py::str(py::getattr(dtype, "cache_key_part"))); + } else { + // Remove 'torch.' prefix from repr of torch.dtype. + py::object repr = py::repr(dtype); + size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr()); + const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()); + if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) { + throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len)); + } + return std::string(repr_ptr + 6, repr_len - 6); + } +} + // Launch void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, @@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f cache_key += "1"; continue; } - // long and int have different kernels - if(!overflow & (std::abs(value) <= 0xffffffff)){ + // int32, uint32, int64, and uint64 have different kernels + if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) { cache_key += "int32"; params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); std::memcpy(params_ptr, &value, 4); params_ptr += 4; - } - else{ + } else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) { + cache_key += "uint32"; + params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); + std::memcpy(params_ptr, &value, 4); + params_ptr += 4; + } else if (!overflow) { cache_key += "int64"; params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); - if(overflow){ - unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr); - std::memcpy(&value, &uvalue, 8); - } std::memcpy(params_ptr, &value, 8); params_ptr += 8; + } else { + if (PyErr_Occurred()) { + throw std::logic_error("An error occurred?"); + } + unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr); + if (PyErr_Occurred()) { + throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg))); + } + cache_key += "uint64"; + params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); + std::memcpy(params_ptr, &unsigned_value, 8); + params_ptr += 8; } if(!specialize) continue; @@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); std::memcpy(params_ptr, &value, 8); params_ptr += 8; - py::object dtype = arg.attr("dtype"); - py::object repr = py::repr(dtype); - assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6)); - const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' 
- size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; - cache_key += std::string(start, len); + cache_key += dtype_cache_key_part(arg.attr("dtype")); cache_key += "*"; cache_key += "[multipleof("; cache_key += pow2_divisor(value); @@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) { .def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference) + .def("get_uint8", &ir::type::get_uint8_ty, ret::reference) + .def("get_uint16", &ir::type::get_uint16_ty, ret::reference) + .def("get_uint32", &ir::type::get_uint32_ty, ret::reference) + .def("get_uint64", &ir::type::get_uint64_ty, ret::reference) .def("is_void", &ir::type::is_void_ty) .def("is_fp8", &ir::type::is_fp8_ty) @@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) { .def("is_bf16", &ir::type::is_bf16_ty) .def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) + .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) + .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) + .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) + .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) @@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) { .def("get_int1", &ir::builder::get_int1, ret::reference) .def("get_int32", &ir::builder::get_int32, ret::reference) .def("get_int64", &ir::builder::get_int64, ret::reference) + .def("get_uint32", &ir::builder::get_uint32, ret::reference) + .def("get_uint64", &ir::builder::get_uint64, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference) .def("get_range", &ir::builder::get_range, ret::reference); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index fe33c9c6a..41c9e9236 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,7 +1,7 @@ import copy import itertools import re -from typing import Optional +from typing import Optional, Union import numpy as np import pytest @@ -10,17 +10,20 @@ from numpy.random import RandomState import triton import triton.language as tl +from triton.code_gen import TensorWrapper, reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] 
float_dtypes = ['float16', 'float32', 'float64'] -dtypes = int_dtypes + float_dtypes +dtypes = int_dtypes + uint_dtypes + float_dtypes + def _bitwidth(dtype: str) -> int: # ex.: "int64" -> 64 return int(re.search(r'(\d+)$', dtype).group(1)) -def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): """ Override `rs` if you're calling this function twice and don't want the same result for both calls. @@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): if rs is None: rs = RandomState(seed=17) dtype = getattr(np, dtype_str) - if dtype_str in int_dtypes: + if dtype_str in int_dtypes + uint_dtypes: iinfo = np.iinfo(getattr(np, dtype_str)) - x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + x = rs.randint(low, high, shape, dtype=dtype) x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. return x elif dtype_str in float_dtypes: @@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): raise RuntimeError(f'Unknown dtype {dtype_str}') -def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor: - # For now, this always converts to a torch tensor, but when we add unsigned - # integers, it will also support TensorWrapper, since torch doesn't have - # unsigned support. - return torch.tensor(x, device=device) +def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]: + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') def to_numpy(x): - if isinstance(x, torch.Tensor): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): return x.cpu().numpy() else: raise ValueError(f"Not a triton-compatible tensor: {x}") @@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: Given two dtype strings, returns the numpy dtype Triton thinks binary operations on the two types should return. Returns None if the return value matches numpy. This is generally needed because Triton and pytorch return - narrower floating point types than numpy in mixed operations. + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. 
""" overrides = { ('float16', 'int16'): np.float16, ('float16', 'int32'): np.float16, ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, } key = (a, b) if a < b else (b, a) return overrides.get(key) -def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'): +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y= # inputs rs = RandomState(17) x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) - y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) if mode_x == 'nan': x[:] = float('nan') if mode_y == 'nan': @@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ('int64', 'float16'), ('int64', 'float32'), ('int64', 'float64'), + ('uint16', 'float16'), + ('uint16', 'float32'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), ] # --------------- @@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' - if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes: + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): @@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. 
numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' else: numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): with pytest.raises(AssertionError, match='Not equal to tolerance'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + elif (op in ('%', '/') and + ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.code_gen.CompilationError) as exc_info: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) else: _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) +@pytest.mark.parametrize("dtype_x, dtype_y", + [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] +) +def test_floordiv(dtype_x, dtype_y, device='cuda'): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + + # --------------- # test bitwise ops # --------------- @@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): ]) def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None if 'float' in dtype_x + dtype_y: with pytest.raises(triton.code_gen.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) # The CompilationError must have been caused by a C++ exception with this text. 
assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: - _test_binary(dtype_x, dtype_y, expr, device=device) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + + +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ + (dtype_x, dtype_y, op) + for op in ['<<', '>>'] + for dtype_x in int_dtypes + uint_dtypes + for dtype_y in int_dtypes + uint_dtypes +]) +def test_shift_op(dtype_x, dtype_y, op, device='cuda'): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65) # --------------- @@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<='] ]) def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): expr = f'x {op} y' - _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device) + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) # --------------- @@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): # --------------- @pytest.mark.parametrize("dtype_x, expr", [ (dtype_x, ' -x') for dtype_x in dtypes -] + [\ +] + [ (dtype_x, ' ~x') for dtype_x in int_dtypes - ]) +]) def test_unary_op(dtype_x, expr, device='cuda'): _test_unary(dtype_x, expr, device=device) @@ -275,8 +367,9 @@ def make_ptr_str(name, shape): @pytest.mark.parametrize("expr, dtype_str", [ - (f'x[{s}]', 'int32') + (f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16'] ]) def test_index1d(expr, dtype_str, device='cuda'): rank_x = expr.count(':') @@ -364,9 +457,9 @@ def test_tuples(): @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([ [ ('add', 'float16', mode), - ('add', 'int32', mode), ('add', 'float32', mode), - ('max', 'int32', mode), ('max', 'float32', mode), - ('min', 'int32', mode), ('min', 'float32', mode), + ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode), + ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode), + ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode), ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): @@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): if exact: assert z_ref.item() == to_numpy(z_tri).item() else: - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # --------------- @@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), ('float32', 'int32', True), -] -) +] + [ + (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] +] + [ + (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] +]) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. 
x0 = 43 if dtype_x in int_dtypes else 43.5 @@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'): @pytest.mark.parametrize("dtype_str, shape, axis", [ - ('float32', (1, 1024), 1) + (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32'] ]) def test_reduce2d(dtype_str, shape, axis, device='cuda'): # triton kernel @@ -762,3 +858,43 @@ def test_noop(device='cuda'): pass x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [ + (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'), + (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64') +]) +def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device='cuda') + pgm = kernel[(1, )](value, x) + + # Parse out the type of the 'VALUE' parameter from the Triton IR. + triton_ir = pgm.asm['ttir'] + ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) + ir_value_type = None if ir_value_match is None else ir_value_match.group(1) + assert ir_value_type == value_type + + +@pytest.mark.parametrize( + "value, overflow", + [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] +) +def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device='cuda') + + if overflow: + with pytest.raises(RuntimeError, match='integer overflow'): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 4d4501556..67173adfb 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'): N = x.numel() grid = (triton.cdiv(N, BLOCK),) kernel[grid](x, N, seed) + assert all((x >= 0) & (x <= 1)) assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 # test normal PRNG diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8393f2b87..eec36f052 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor): return triton.language.constexpr(not op) if isinstance(op, triton.language.core.constexpr): op = op.value - # print(op) fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -503,6 +502,7 @@ class Binary: self.shared_mem = shared_mem self.num_warps = num_warps + class LoadedBinary: def __init__(self, device: int, bin: Binary): module, kernel = _triton.code_gen.load_binary(bin.backend, @@ -571,24 +571,33 @@ class Kernel: torch.int16: 'i16', torch.int32: 'i32', torch.int64: 'i64', + triton.language.uint8: 'u8', + triton.language.uint16: 'u16', + triton.language.uint32: 'u32', + triton.language.uint64: 'u64', } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] if isinstance(obj, triton.language.core.constexpr): obj = obj.value if isinstance(obj, int): - if abs(obj) <= 0xffffffff: - return 'I' - return 'L' + if -2**31 <= obj < 2**31: + return 'i32' + elif 2**31 <= obj < 2**32: + return 'u32' + elif -2**63 <= obj < 2**63: + return 'i64' + elif 2**63 <= obj < 2**64: + return 'u64' + else: + raise ValueError(f'integer overflow representing {obj}') if isinstance(obj, float): return 'f' if isinstance(obj, bool): return 'B' if isinstance(obj, 
str): return 'str' - assert False - - + raise NotImplementedError(f'could not compute type name for {obj}') @staticmethod def _to_triton_ir(context, obj): @@ -607,6 +616,10 @@ class Kernel: 'i16': _triton.ir.type.get_int16, 'i32': _triton.ir.type.get_int32, 'i64': _triton.ir.type.get_int64, + 'u8': _triton.ir.type.get_uint8, + 'u16': _triton.ir.type.get_uint16, + 'u32': _triton.ir.type.get_uint32, + 'u64': _triton.ir.type.get_uint64, } # convert torch.Tensor to Triton IR pointers if hasattr(obj, 'data_ptr'): @@ -1165,4 +1178,15 @@ class TensorWrapper: def reinterpret(tensor, dtype): - return TensorWrapper(tensor, dtype) + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d5d3313e5..210a72a30 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -9,9 +9,16 @@ def _to_ir(x, builder): if isinstance(x, bool): return builder.get_int1(x) elif isinstance(x, int): - if x.__abs__() <= 2**31: + if -2**31 <= x < 2**31: return builder.get_int32(x) - return builder.get_int64(x) + elif 2**31 <= x < 2**32: + return builder.get_uint32(x) + elif -2**63 <= x < 2**63: + return builder.get_int64(x) + elif 2**63 <= x < 2**64: + return builder.get_uint64(x) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): return builder.get_float32(x) elif isinstance(x, constexpr): @@ -83,6 +90,14 @@ class dtype: def __str__(self): return self.name + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + return f'triton.language.{self.name}' + class pointer_dtype: def __init__(self, element_ty): @@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8) int16 = dtype(ir.type.get_int16) int32 = dtype(ir.type.get_int32) int64 = dtype(ir.type.get_int64) +uint8 = dtype(ir.type.get_uint8) +uint16 = dtype(ir.type.get_uint16) +uint32 = dtype(ir.type.get_uint32) +uint64 = dtype(ir.type.get_uint64) float8 = dtype(ir.type.get_fp8) float16 = dtype(ir.type.get_fp16) bfloat16 = dtype(ir.type.get_bf16) @@ -120,6 +139,10 @@ class block: if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 + if ir_type.is_uint8(): return uint8 + if ir_type.is_uint16(): return uint16 + if ir_type.is_uint32(): return uint32 + if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16 From 001fb757fe0fea4f82e721378ed2f9dea35513ba Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 6 Jan 2022 12:56:22 -0500 Subject: [PATCH 038/215] [OPS][BLOCKSPARSE] Added `.contiguous()` in blocksparse inputs when necessary (#420) --- python/triton/ops/blocksparse/matmul.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 9c3317fe0..ce15c9af4 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -76,6 +76,10 @@ def _sdd_kernel( tl.store(pc, c, 
mask=True) def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() # (A * B)^T = B^T * A^T if trans_c: a, b = b, a @@ -190,6 +194,10 @@ def _dsd_kernel( tl.store(pc, c, mask = offs_cn[None, :] < DS0) def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() # shapes / dtypes AS1 = block * spdims[2 if trans_a else 1] BS0 = b.size(0) @@ -378,6 +386,10 @@ def _dds_kernel( tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() # shapes / dtypes AS0 = a.size(0) AS1 = a.size(1) From 120cda015eaf541ab9f61237e21dffb9688d3b12 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Thu, 6 Jan 2022 10:49:09 -0800 Subject: [PATCH 039/215] [FRONTEND] use unsigned integers to simplify RNG (#417) --- python/triton/language/random.py | 61 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index e1ac3c30a..cb2ddfc6b 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -2,24 +2,16 @@ import triton from . import core as tl -# Notes -# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations. -# 2. multiply_low_high is currently inefficient. -# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float - -PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 -PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 -PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 -PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57 -N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox +PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 +PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 +PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 +PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57 +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox # ------------------- # randint # ------------------- -@triton.jit -def hacky_to_uint64(x): - return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64) @triton.jit def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): @@ -40,12 +32,13 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): k1 = k1 + PHILOX_KEY_B return c0, c1, c2, c3 + @triton.jit def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, returns a single - block of random :code:`int32`. - + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + If you need multiple streams of random numbers, using `randint4x` is likely to be faster than calling `randint` 4 times. 
@@ -55,23 +48,23 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): ret, _, _, _ = randint4x(seed, offset, n_rounds) return ret + @triton.jit def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, returns four - blocks of random :code:`int32`. - - This is the maximally efficient entry point + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point to Triton's Philox pseudo-random number generator. :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting - seed = seed + 0 - seed = hacky_to_uint64(seed) # uint will solve this - seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) - seed_lo = (seed & 0xffffffff).to(tl.int32) + z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting + seed = seed.to(tl.uint64) + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) @@ -82,18 +75,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): @triton.jit def uint32_to_uniform_float(x): """ - Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). - This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly - covers all the possible values it can take. + Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. - x = tl.where(x < 0, -x - 1, x) - return x * max + two_to_the_minus_32 = 2.328306e-10 + return x * two_to_the_minus_32 + @triton.jit def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, + Given a :code:`seed` scalar and an :code:`offset` block, returns a block of random :code:`float32` in :math:`U(0, 1)` :param seed: The seed for generating random numbers. @@ -102,6 +93,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): source = randint(seed, offset, n_rounds) return uint32_to_uniform_float(source) + @triton.jit def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ @@ -122,6 +114,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # randn # ------------------- + @triton.jit def pair_uniform_to_normal(u1, u2): """Box-Muller transform""" @@ -130,10 +123,11 @@ def pair_uniform_to_normal(u1, u2): r = tl.sqrt(-2.0 * tl.log(u1)) return r * tl.cos(th), r * tl.sin(th) + @triton.jit def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Given a :code:`seed` scalar and an :code:`offset` block, + Given a :code:`seed` scalar and an :code:`offset` block, returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` :param seed: The seed for generating random numbers. 
@@ -145,6 +139,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): n1, _ = pair_uniform_to_normal(u1, u2) return n1 + @triton.jit def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ From 6f7acad48fd94384f5b1fe0148e2c7f584a6aefd Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Thu, 6 Jan 2022 12:04:33 -0800 Subject: [PATCH 040/215] [CODEGEN] Avoid use of deprecated AST nodes (#418) Co-authored-by: hauntsaninja <> --- python/triton/code_gen.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8622333bf..90018028d 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -305,9 +305,6 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) - def visit_Str(self, node): - return ast.literal_eval(node) - def visit_Subscript(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) @@ -374,9 +371,6 @@ class CodeGenerator(ast.NodeVisitor): def visit_Index(self, node): return self.visit(node.value) - def visit_NameConstant(self, node): - return node.value - def visit_keyword(self, node): return {node.arg: self.visit(node.value)} @@ -393,8 +387,18 @@ class CodeGenerator(ast.NodeVisitor): return fn(*args, _builder=self.builder, **kws) return fn(*args, **kws) - def visit_Num(self, node): - return node.n + def visit_Constant(self, node): + return node.value + + if sys.version_info < (3, 8): + def visit_NameConstant(self, node): + return node.value + + def visit_Num(self, node): + return node.n + + def visit_Str(self, node): + return ast.literal_eval(node) def visit_Attribute(self, node): lhs = self.visit(node.value) From 8bf551ae7a982f3ac8083bdd407ad9ecc3f39b21 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Thu, 6 Jan 2022 14:34:17 -0800 Subject: [PATCH 041/215] [STYLE] run autopep8 and isort (#421) Run: ``` isort ./python autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py') ``` with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review. 
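
The `.isort.cfg` mentioned above is not part of this change. Purely as an
illustration of the kind of settings the `isort ./python` invocation assumes
(hypothetical values, not the committed file), it would look roughly like:

```
# Hypothetical sketch only; the actual .isort.cfg is committed separately.
[settings]
line_length = 120
known_first_party = triton
known_third_party = numpy,pytest,torch
multi_line_output = 3
include_trailing_comma = true
```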
--- python/bench/bench_blocksparse.py | 53 ++--- python/bench/bench_cross_entropy.py | 25 +-- python/bench/bench_matmul.py | 22 +- python/bench/run.py | 5 +- python/setup.py | 35 ++-- python/test/regression/test_performance.py | 83 ++++---- python/test/unit/language/test_core.py | 152 ++++++++------ python/test/unit/language/test_random.py | 40 ++-- .../test/unit/operators/test_blocksparse.py | 8 +- .../test/unit/operators/test_cross_entropy.py | 20 +- python/test/unit/operators/test_matmul.py | 12 +- python/test/unit/runtime/test_cache.py | 23 ++- python/test/unit/runtime/test_comm.py | 10 +- python/triton/code_gen.py | 115 ++++++----- python/triton/language/core.py | 77 ++++--- python/triton/language/random.py | 1 - python/triton/ops/__init__.py | 4 +- python/triton/ops/blocksparse/__init__.py | 2 +- python/triton/ops/blocksparse/matmul.py | 134 ++++++------ python/triton/ops/blocksparse/softmax.py | 53 ++--- python/triton/ops/cross_entropy.py | 6 +- python/triton/ops/matmul.py | 52 ++--- python/triton/ops/matmul_perf_model.py | 193 +++++++++--------- python/triton/testing.py | 44 ++-- python/triton/tools/disasm.py | 2 +- python/tutorials/01-vector-add.py | 4 +- python/tutorials/02-fused-softmax.py | 11 +- python/tutorials/03-matrix-multiplication.py | 46 +++-- python/tutorials/04-low-memory-dropout.py | 18 +- python/tutorials/05-layer-norm.py | 115 ++++++----- 30 files changed, 742 insertions(+), 623 deletions(-) diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py index b6eacd884..d678f49f8 100644 --- a/python/bench/bench_blocksparse.py +++ b/python/bench/bench_blocksparse.py @@ -1,4 +1,5 @@ import torch + import triton # ------------------------------- @@ -8,18 +9,18 @@ import triton nt = {False: 'n', True: 't'} square_confs = [ triton.testing.Benchmark( - x_names = ['M', 'N', 'K'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - line_arg = 'block', - line_vals = [16, 32, 64, 128], - line_names = ['Block16', 'Block32', 'Block64', 'Block128'], - ylabel = 'TFLOPS', - plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', - args = {'layout_mode': layout_mode, 'op_mode': op_mode, - 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} - )\ - for AT in [False] for BT in [False] \ - for op_mode in ['dsd'] for layout_mode in ['dense'] + x_names=['M', 'N', 'K'], + x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144], + line_arg='block', + line_vals=[16, 32, 64, 128], + line_names=['Block16', 'Block32', 'Block64', 'Block128'], + ylabel='TFLOPS', + plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', + args={'layout_mode': layout_mode, 'op_mode': op_mode, + 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} + ) + for AT in [False] for BT in [False] + for op_mode in ['dsd'] for layout_mode in ['dense'] ] @@ -27,7 +28,7 @@ square_confs = [ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): Z, H = 1, 1 make_layout = { - 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ + 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), }[layout_mode] # create layout @@ -45,10 +46,10 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) num_flops = { - 'sdd': 2 
* Z * K * float(layout.sum()) * block * block,\ - 'dsd': 2 * Z * N * float(layout.sum()) * block * block,\ + 'sdd': 2 * Z * K * float(layout.sum()) * block * block, + 'dsd': 2 * Z * N * float(layout.sum()) * block * block, 'dds': 2 * Z * M * float(layout.sum()) * block * block - }[op_mode]*1e-12 + }[op_mode] * 1e-12 return tflops(mean_ms), tflops(min_ms), tflops(max_ms) @@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, square_confs = [ triton.testing.Benchmark( - x_names = ['M', 'N'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - line_arg = 'block', - line_vals = [16, 32, 64], - line_names = ['Block16', 'Block32', 'Block64'], - ylabel = 'GBPS', - plot_name = f'{layout_mode}-square', - args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} - )\ + x_names=['M', 'N'], + x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144], + line_arg='block', + line_vals=[16, 32, 64], + line_names=['Block16', 'Block32', 'Block64'], + ylabel='GBPS', + plot_name=f'{layout_mode}-square', + args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} + ) for layout_mode in ['dense', 'tril'] ] @@ -88,4 +89,4 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50): return gbps(mean_ms), gbps(min_ms), gbps(max_ms) -bench_matmul.run(print_data=True, show_plots=True) \ No newline at end of file +bench_matmul.run(print_data=True, show_plots=True) diff --git a/python/bench/bench_cross_entropy.py b/python/bench/bench_cross_entropy.py index 5347ae24a..aaa0e28f5 100644 --- a/python/bench/bench_cross_entropy.py +++ b/python/bench/bench_cross_entropy.py @@ -1,17 +1,18 @@ import torch + import triton confs = [ triton.testing.Benchmark( - x_names = ['N'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], - line_arg = 'provider', - line_vals = ['triton', 'torch'], - line_names = ['Triton', 'Torch'], - ylabel = 'GBPS', - plot_name = f'{mode}-2048', - args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} - )\ + x_names=['N'], + x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], + line_arg='provider', + line_vals=['triton', 'torch'], + line_names=['Triton', 'Torch'], + ylabel='GBPS', + plot_name=f'{mode}-2048', + args={'M': 2048, 'dtype': torch.float16, 'mode': mode} + ) for mode in ['forward', 'backward'] ] @@ -24,8 +25,8 @@ def bench_op(M, N, dtype, mode, provider): num_gb = (2 * x.numel() * x.element_size() * 1e-9) gbps = lambda ms: num_gb / ms * 1e3 # forward pass - op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \ - 'triton': triton.ops.cross_entropy}[provider] + op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), + 'triton': triton.ops.cross_entropy}[provider] if mode == 'forward': mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx)) if mode == 'backward': @@ -37,4 +38,4 @@ def bench_op(M, N, dtype, mode, provider): if __name__ == '__main__': - bench_op.run(print_data=True) \ No newline at end of file + bench_op.run(print_data=True) diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 7e912be31..9db005da0 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -1,6 +1,6 @@ -import triton import torch -import os + +import triton def rounded_linspace(low, high, steps, div): @@ -29,16 +29,16 @@ square_confs = [ transformer_confs = [ triton.testing.Benchmark( x_names=[x], - x_vals = rounded_linspace(NK//16, NK, 32, 128), + x_vals=rounded_linspace(NK // 16, NK, 32, 128), 
line_arg="provider", line_vals=["cublas", "triton", "cutlass"], line_names=["cuBLAS", "Triton", "CUTLASS"], ylabel="TFLOPS", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", - args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} - ) for NK in [12288]\ - for i, x in enumerate(["N", "K"])\ - for M in [2048] + args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16} + ) for NK in [12288] + for i, x in enumerate(["N", "K"]) + for M in [2048] ] @@ -46,8 +46,10 @@ transformer_confs = [ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) - if AT: a = a.t() - if BT: b = b.t() + if AT: + a = a.t() + if BT: + b = b.t() num_flops = 2 * M * N * K tflops = lambda ms: 2. * M * N * K / ms * 1e-9 if provider == "cublas": @@ -61,6 +63,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): try: ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep) return tflops(ms), tflops(max_ms), tflops(min_ms) - except: + except Exception: return None return None diff --git a/python/bench/run.py b/python/bench/run.py index c23884bb5..5e6e3b392 100644 --- a/python/bench/run.py +++ b/python/bench/run.py @@ -1,7 +1,8 @@ import argparse -import sys -import os import inspect +import os +import sys + import triton diff --git a/python/setup.py b/python/setup.py index 17db76093..28194f41e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,29 +1,28 @@ -import os -import re -import sys -import sysconfig -import platform -import subprocess import distutils -import glob -import tempfile -import shutil -from distutils.version import LooseVersion -from setuptools import setup, Extension, find_packages -from setuptools.command.build_ext import build_ext -from setuptools.command.test import test as TestCommand import distutils.spawn -import urllib.request +import os +import platform +import re +import shutil +import subprocess +import sys import tarfile +import tempfile +import urllib.request +from distutils.version import LooseVersion + +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext + def get_llvm(): # tries to find system LLVM - versions = ['-11.0', '-11', '-11-64'] + versions = ['-11.0', '-11', '-11-64'] supported = ['llvm-config{v}'.format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [p for p in paths if p is not None] if paths: - return '', '' + return '', '' # download if nothing is installed name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04' dir = '/tmp' @@ -32,7 +31,7 @@ def get_llvm(): if not os.path.exists(llvm_library_dir): try: shutil.rmtree(os.path.join(dir, name)) - except: + except Exception: pass url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name) print('downloading and extracting ' + url + '...') @@ -96,7 +95,7 @@ class CMakeBuild(build_ext): "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir, "-DLLVM_LIBRARY_DIR=" + llvm_library_dir, #'-DPYTHON_EXECUTABLE=' + sys.executable, - #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', + # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) ] diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index ce93786b8..012ff65d7 
100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -1,14 +1,18 @@ -from numpy import record -import torch -import triton +import triton.language as tl import subprocess import sys + import pytest +import torch +from numpy import record + +import triton ####################### # Utilities ####################### + def nvsmi(attrs): attrs = ','.join(attrs) cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] @@ -23,48 +27,51 @@ def nvsmi(attrs): ####################### matmul_data = { - # square - (256 , 256 , 256 ) : {'v100': 0.027}, - (512 , 512 , 512 ) : {'v100': 0.158}, - (1024, 1024, 1024 ) : {'v100': 0.466}, - (2048, 2048, 2048 ) : {'v100': 0.680}, - (4096, 4096, 4096 ) : {'v100': 0.831}, - (8192, 8192, 8192 ) : {'v100': 0.849}, - # tall-skinny - (16 , 1024, 1024 ) : {'v100': 0.0128}, - (16 , 4096, 4096 ) : {'v100': 0.0883}, - (16 , 8192, 8192 ) : {'v100': 0.101}, - (64 , 1024, 1024 ) : {'v100': 0.073}, - (64 , 4096, 4096 ) : {'v100': 0.270}, - (64 , 8192, 8192 ) : {'v100': 0.360}, - (1024, 64 , 1024 ) : {'v100': 0.0692}, - (4096, 64 , 4096 ) : {'v100': 0.264}, - (8192, 64 , 8192 ) : {'v100': 0.323}, -# # deep reductions -# (64 , 64 , 16384) : {'v100': 0.}, -# (64 , 64 , 65536) : {'v100': 0.}, -# (256 , 256 , 8192 ) : {'v100': 0.}, -# (256 , 256 , 32768) : {'v100': 0.}, + # square + (256, 256, 256): {'v100': 0.027}, + (512, 512, 512): {'v100': 0.158}, + (1024, 1024, 1024): {'v100': 0.466}, + (2048, 2048, 2048): {'v100': 0.680}, + (4096, 4096, 4096): {'v100': 0.831}, + (8192, 8192, 8192): {'v100': 0.849}, + # tall-skinny + (16, 1024, 1024): {'v100': 0.0128}, + (16, 4096, 4096): {'v100': 0.0883}, + (16, 8192, 8192): {'v100': 0.101}, + (64, 1024, 1024): {'v100': 0.073}, + (64, 4096, 4096): {'v100': 0.270}, + (64, 8192, 8192): {'v100': 0.360}, + (1024, 64, 1024): {'v100': 0.0692}, + (4096, 64, 4096): {'v100': 0.264}, + (8192, 64, 8192): {'v100': 0.323}, + # # deep reductions + # (64 , 64 , 16384) : {'v100': 0.}, + # (64 , 64 , 65536) : {'v100': 0.}, + # (256 , 256 , 8192 ) : {'v100': 0.}, + # (256 , 256 , 32768) : {'v100': 0.}, } + + @pytest.mark.parametrize('M, N, K', matmul_data.keys()) def test_matmul(M, N, K): ref_gpu_util = matmul_data[(M, N, K)]['v100'] cur_sm_clock = nvsmi(['clocks.current.sm'])[0] ref_sm_clock = 1350 - max_gpu_perf = 1e-6*80*8*128*cur_sm_clock + max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz' a = torch.randn((M, K), dtype=torch.float16, device='cuda') b = torch.randn((K, N), dtype=torch.float16, device='cuda') fn = lambda: triton.ops.matmul(a, b) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) - cur_gpu_perf = 2.*M*N*K/ms * 1e-9 + cur_gpu_perf = 2. 
* M * N * K / ms * 1e-9 cur_gpu_util = cur_gpu_perf / max_gpu_perf triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) + ####################### # Element-Wise ####################### -import triton.language as tl + @triton.jit def _add(x_ptr, y_ptr, output_ptr, n_elements, @@ -80,21 +87,22 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements, elementwise_data = { - 1024*16 : {'v100': 0.0219}, - 1024*64 : {'v100': 0.0791}, - 1024*256 : {'v100': 0.243}, - 1024*1024 : {'v100': 0.534}, - 1024*4096 : {'v100': 0.796}, - 1024*16384: {'v100': 0.905}, - 1024*65536: {'v100': 0.939}, + 1024 * 16: {'v100': 0.0219}, + 1024 * 64: {'v100': 0.0791}, + 1024 * 256: {'v100': 0.243}, + 1024 * 1024: {'v100': 0.534}, + 1024 * 4096: {'v100': 0.796}, + 1024 * 16384: {'v100': 0.905}, + 1024 * 65536: {'v100': 0.939}, } + @pytest.mark.parametrize('N', elementwise_data.keys()) def test_elementwise(N): ref_gpu_util = elementwise_data[N]['v100'] cur_mem_clock = nvsmi(['clocks.current.memory'])[0] ref_mem_clock = 877 - max_gpu_perf = 512*2*ref_mem_clock*1e-3 + max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3 assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) @@ -102,7 +110,6 @@ def test_elementwise(N): grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250) - cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6 + cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6 cur_gpu_util = cur_gpu_perf / max_gpu_perf triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) - diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 41c9e9236..7f0af78b4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -86,6 +86,7 @@ def patch_kernel(template, to_replace): @pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 + @triton.jit def kernel(X, SIZE: tl.constexpr): pass @@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'): def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): SIZE = 128 # define the kernel / launch-grid + @triton.jit def kernel(Z, X, SIZE: tl.constexpr): off = tl.arange(0, SIZE) @@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): SIZE = 128 # define the kernel / launch-grid + @triton.jit def kernel(Z, X, Y, SIZE: tl.constexpr): off = tl.arange(0, SIZE) @@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: # --------------- # test binary ops # --------------- + + @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) - for op in ['+', '-', '*', '/', '%'] - for dtype_x in dtypes - for dtype_y in dtypes + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes + for dtype_y in dtypes ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' @@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): @pytest.mark.parametrize("dtype_x, dtype_y", - [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + - [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] -) + [(dtype_x, dtype_y) for 
dtype_x in int_dtypes for dtype_y in int_dtypes] + + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] + ) def test_floordiv(dtype_x, dtype_y, device='cuda'): # Triton has IEEE, not numpy/torch, semantics for %, and those carry # through to //, so we have to use a nonstandard expression to get a @@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'): # test compare ops # --------------- ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \ -# real -[ - (dtype_x, dtype_y, op, 'real', 'real') \ - for op in ops \ - for dtype_x in dtypes \ - for dtype_y in dtypes -] + \ -# NaNs -[('float32', 'float32', op, mode_x, mode_y) \ - for op in ops - for mode_x, mode_y in [('nan' , 'real'), - ('real', 'nan'), - ('nan' , 'nan')] -]) + +@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", + # real + [ + (dtype_x, dtype_y, op, 'real', 'real') + for op in ops + for dtype_x in dtypes + for dtype_y in dtypes + ] + + # NaNs + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), + ('real', 'nan'), + ('nan', 'nan')] + + ]) def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): expr = f'x {op} y' if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): @@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'): # 'exp', 'log', 'cos', 'sin' # ]) + @pytest.mark.parametrize("expr", [ 'exp', 'log', 'cos', 'sin' ]) @@ -368,8 +376,8 @@ def make_ptr_str(name, shape): @pytest.mark.parametrize("expr, dtype_str", [ (f'x[{s}]', d) - for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] - for d in ['int32', 'uint32', 'uint16'] + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16'] ]) def test_index1d(expr, dtype_str, device='cuda'): rank_x = expr.count(':') @@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'): @triton.jit def fn(a, b): return a + b, \ - a - b, \ - a * b + a - b, \ + a * b def test_tuples(): @@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): # --------------- @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ (dtype_x, dtype_z, False) - for dtype_x in dtypes - for dtype_z in dtypes + for dtype_x in dtypes + for dtype_z in dtypes ] + [ ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), @@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): @triton.jit def kernel(X, Z, BITCAST: tl.constexpr): x = tl.load(X) - z = x.to(Z.dtype.element_ty, bitcast = BITCAST) + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) tl.store(Z, z) # triton result @@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # --------------- # test reduce # --------------- + + @pytest.mark.parametrize("dtype_str, shape", - [(dtype, shape) \ - for dtype in dtypes\ - for shape in [128, 512]]) + [(dtype, shape) + for dtype in dtypes + for shape in [128, 512]]) def test_reduce1d(dtype_str, shape, device='cuda'): # triton kernel @@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'): def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) - x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :]) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) z = tl.sum(x, axis=AXIS) tl.store(Z + range_m, z) # input @@ -608,11 +618,13 @@ def 
test_reduce2d(dtype_str, shape, axis, device='cuda'): # --------------- # test permute # --------------- + + @pytest.mark.parametrize("dtype_str, shape, perm", - [(dtype, shape, perm) \ - for dtype in ['float32']\ - for shape in [(128, 128)]\ - for perm in [(1, 0)]]) + [(dtype, shape, perm) + for dtype in ['float32'] + for shape in [(128, 128)] + for perm in [(1, 0)]]) def test_permute(dtype_str, shape, perm, device='cuda'): # triton kernel @@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # test dot # --------------- + @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) def test_dot(epilogue, device='cuda'): # triton kernel @@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'): y_tri, y_tri.stride(0), y_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, - ADD_MATRIX = epilogue=='add-matrix', - ADD_ROWS = epilogue=='add-rows', - ADD_COLS = epilogue=='add-cols') + ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', + ADD_COLS=epilogue == 'add-cols') # torch result z_ref = np.matmul(x, y) if epilogue == 'add-matrix': z_ref += z if epilogue == 'add-rows': - z_ref += z[:,0][:, None] + z_ref += z[:, 0][:, None] if epilogue == 'add-cols': - z_ref += z[0,:][None, :] + z_ref += z[0, :][None, :] # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # make sure ld/st are vectorized @@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'): assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx + def test_dot_without_load(): @triton.jit def kernel(out): @@ -713,28 +727,30 @@ def test_dot_without_load(): b = tl.zeros((32, 32), tl.float32) c = tl.zeros((32, 32), tl.float32) c = tl.dot(a, b) - pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :] + pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] tl.store(pout, c) - out = torch.ones((32,32), dtype=torch.float32, device="cuda") + out = torch.ones((32, 32), dtype=torch.float32, device="cuda") kernel[(1,)](out) # --------------- # test arange # --------------- + @pytest.mark.parametrize("start", [0, 1, 7, 16]) def test_arange(start, device='cuda'): BLOCK = 128 z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + @triton.jit def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): off = tl.arange(0, BLOCK) val = tl.arange(START, END) tl.store(z + off, val) - _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK) - z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device) + _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) triton.testing.assert_almost_equal(z_tri, z_ref) # --------------- @@ -742,6 +758,8 @@ def test_arange(start, device='cuda'): # --------------- # 'bfloat16': torch.bfloat16, # Testing masked loads with an intermate copy to shared memory run. + + @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device='cuda'): M = 32 @@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'): N_offsets = tl.arange(0, N) K_offsets = tl.arange(0, K) - in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:] - in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:] + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] # Load inputs. 
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel) @@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'): o = tl.dot(x, w) # Store output - output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:] + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel) pgm = _kernel[(1,)](in1, in2, out, - in1.stride()[0], - in2.stride()[0], - out.stride()[0], - in1.numel(), - in2.numel(), - out.numel(), - M=M, N=N, K=K) + in1.stride()[0], + in2.stride()[0], + out.stride()[0], + in1.numel(), + in2.numel(), + out.numel(), + M=M, N=N, K=K) - reference_out =torch.matmul(in1, in2) + reference_out = torch.matmul(in1, in2) triton.testing.allclose(out, reference_out) + @pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) def test_load_cache_modifier(cache): src = torch.empty(128, device='cuda') @@ -796,8 +815,8 @@ def test_load_cache_modifier(cache): @triton.jit def _kernel(dst, src, CACHE: tl.constexpr): offsets = tl.arange(0, 128) - x = tl.load(src+offsets, cache_modifier=CACHE) - tl.store(dst+offsets, x) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] @@ -830,11 +849,14 @@ def test_load_cache_modifier(cache): # --------------- # test default # --------------- -#TODO: can't be local to test_default +# TODO: can't be local to test_default + + @triton.jit -def _impl(value = 10): +def _impl(value=10): return value + def test_default(): value = 5 ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') @@ -851,7 +873,9 @@ def test_default(): # --------------- # test noop -#---------------- +# ---------------- + + def test_noop(device='cuda'): @triton.jit def kernel(x): @@ -861,9 +885,9 @@ def test_noop(device='cuda'): @pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'), - (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'), - (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64') + (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') ]) def test_value_specialization(value: int, value_type: str, device='cuda') -> None: diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 67173adfb..82ae7f0c2 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -1,16 +1,17 @@ -import torch -import triton -import triton.language as tl +import numpy as np import pytest import scipy.stats -import numpy as np - +import torch from numpy.random import Philox +import triton +import triton.language as tl + ##################################### -## Reference Philox Implementation +# Reference Philox Implementation ##################################### + class PhiloxConfig: def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) @@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x): ##################################### -## Unit Tests +# Unit Tests ##################################### BLOCK = 1024 # test generation of random uint32 + + @pytest.mark.parametrize('size, seed', - [(size, seed) for size in ['10', '4,53', '10000']\ - for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] -) + [(size, 
seed) for size in ['10', '4,53', '10000'] + for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] + ) def test_randint(size, seed, device='cuda'): size = list(map(int, size.split(','))) + @triton.jit def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) @@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'): assert out_tri == out_ref # test uniform PRNG + + @pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000]\ - for seed in [0, 42, 124, 54]] -) + [(size, seed) for size in [1000000] + for seed in [0, 42, 124, 54]] + ) def test_rand(size, seed, device='cuda'): @triton.jit def kernel(X, N, seed): @@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'): assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 # test normal PRNG + + @pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000]\ - for seed in [0, 42, 124, 54]] -) + [(size, seed) for size in [1000000] + for seed in [0, 42, 124, 54]] + ) def test_randn(size, seed, device='cuda'): @triton.jit def kernel(X, N, seed): diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index b9cdc23c7..ed569c04d 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -1,6 +1,7 @@ -import torch -import triton import pytest +import torch + +import triton @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) @@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE): # torch result rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) # broadcast at_mask to the same shape as rx - if is_causal: at_mask = torch.tril(at_mask) + if is_causal: + at_mask = torch.tril(at_mask) M = at_mask[None, None, :, :] + torch.zeros_like(rx) rx[M == 0] = float("-inf") # rx += kp_mask[:, None, None, :] diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index 48cb303bb..08516257b 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -1,14 +1,16 @@ -import torch -import triton import pytest +import torch + +import triton + @pytest.mark.parametrize("M, N, dtype, mode", - [ - (M, N, dtype, mode) for M in [1024, 821] - for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32']\ - for mode in ['forward', 'backward'] - ] + [ + (M, N, dtype, mode) for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] + ] ) def test_op(M, N, dtype, mode): dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype] @@ -30,4 +32,4 @@ def test_op(M, N, dtype, mode): x.grad.zero_() th_y.backward(dy) th_dx = x.grad.clone() - triton.testing.assert_almost_equal(th_dx, tt_dx) \ No newline at end of file + triton.testing.assert_almost_equal(th_dx, tt_dx) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 75241c291..1d413a0e6 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -1,8 +1,10 @@ -import pytest import itertools -import triton + +import pytest import torch +import triton + @pytest.mark.parametrize( "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE", @@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, K = BLOCK_K * SPLIT_K if K is None 
else K # allocate/transpose inputs DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] - a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) - b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) + a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) + b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a b = b.t() if BT else b # run test th_c = torch.matmul(a, b) - tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest) + tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest) triton.testing.assert_almost_equal(th_c, tt_c) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 3ad387f09..51c69b5b6 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,13 +1,16 @@ -import torch -import triton -from triton.code_gen import JITFunction -import triton.language as tl import os import shutil + import pytest +import torch + +import triton +import triton.language as tl +from triton.code_gen import JITFunction tmpdir = ".tmp" + @triton.jit def function_1(i): i = i + 1 @@ -20,18 +23,21 @@ def function_2(i): i = i + 1 return i + @triton.jit def kernel(X, i, BLOCK: tl.constexpr): i = i + 1 i = function_1(i) tl.store(X, i) + @triton.jit(do_not_specialize=["i"]) def kernel_nospec(X, i, BLOCK: tl.constexpr): i = i + 1 i = function_1(i) tl.store(X, i) + def apply_src_change(target, old, new): delattr(kernel.fn, 'hash') delattr(function_1.fn, 'hash') @@ -42,28 +48,34 @@ def apply_src_change(target, old, new): target.src = target.src.replace(new, old) return ret + def test_nochange(): baseline = kernel.cache_key updated = apply_src_change(kernel, 'i + 1', 'i + 1') assert baseline == updated + def test_toplevel_change(): baseline = kernel.cache_key updated = apply_src_change(kernel, 'i + 1', 'i + 2') assert baseline != updated + def test_nested1_change(): baseline = kernel.cache_key updated = apply_src_change(function_1, 'i + 1', 'i + 2') assert baseline != updated + def reset_tmp_dir(): os.environ["TRITON_CACHE_DIR"] = tmpdir if os.path.exists(tmpdir): shutil.rmtree(tmpdir) + def test_reuse(): counter = 0 + def inc_counter(key, binary, repr): nonlocal counter counter += 1 @@ -73,11 +85,12 @@ def test_reuse(): for i in range(10): kernel[(1,)](x, 1, BLOCK=1024) assert counter == 1 - + @pytest.mark.parametrize('mode', ['enable', 'disable']) def test_specialize(mode): counter = 0 + def inc_counter(key, binary, repr): nonlocal counter counter += 1 diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py index ae843a15f..6d0658f3b 100644 --- a/python/test/unit/runtime/test_comm.py +++ b/python/test/unit/runtime/test_comm.py @@ -1,9 +1,11 @@ -import torch -import triton -import pytest import subprocess -import triton.language as tl + import numpy as np +import pytest +import torch + +import triton +import triton.language as tl def get_p2p_matrix(): diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index eec36f052..439c1798e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,26 +1,26 @@ import ast import builtins +import dbm import functools -import inspect -import struct -import sys -import textwrap import hashlib +import inspect import os import pickle +import struct import subprocess -import os +import sys +import tempfile +import textwrap +import time import warnings -from .tools.disasm 
import extract +from typing import Dict, Optional + import torch +from filelock import FileLock + import triton import triton._C.libtriton.triton as _triton -from filelock import FileLock -import dbm -import tempfile -from typing import Optional, Dict -import time - +from .tools.disasm import extract class CodeGenerator(ast.NodeVisitor): @@ -100,7 +100,7 @@ class CodeGenerator(ast.NodeVisitor): arg_names, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): - arg_node = node.args.args[-i-1] + arg_node = node.args.args[-i - 1] annotation = arg_node.annotation name = arg_node.arg st_target = ast.Name(id=name, ctx=ast.Store()) @@ -134,8 +134,7 @@ class CodeGenerator(ast.NodeVisitor): fn.args[idx].name = arg_name arg_values.append(fn.args[idx]) idx += 1 - - + for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) if inline: @@ -178,7 +177,6 @@ class CodeGenerator(ast.NodeVisitor): # default: call visit_Assign return self.visit_Assign(node) - def visit_Assign(self, node): _names = [] for target in node.targets: @@ -272,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor): if else_bb: self.builder.set_insert_block(else_bb) is_terminator = self.visit_compound_statement(node.orelse) - #TODO: last statement is a terminator? + # TODO: last statement is a terminator? if not is_terminator: self.builder.br(endif_bb) self.module.seal_block(endif_bb) @@ -404,10 +402,10 @@ class CodeGenerator(ast.NodeVisitor): pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) - build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ - self.visit(pos_cond_node),\ - self.visit(neg_cond_node),\ - _builder=self.builder) + build_cond = lambda: triton.language.where(self.visit(pos_step_node), + self.visit(pos_cond_node), + self.visit(neg_cond_node), + _builder=self.builder) #cond_node = neg_cond_node step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation @@ -462,7 +460,7 @@ class CodeGenerator(ast.NodeVisitor): if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language.core: + sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) return fn(*args, **kws) @@ -505,10 +503,10 @@ class Binary: class LoadedBinary: def __init__(self, device: int, bin: Binary): - module, kernel = _triton.code_gen.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, + module, kernel = _triton.code_gen.load_binary(bin.backend, + bin.name, + bin.asm, + bin.shared_mem, device) self.bin = bin self.asm = bin.asm @@ -520,8 +518,8 @@ class LoadedBinary: def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, - grid_0, grid_1, grid_2, - self.bin.num_warps * 32, 1, 1, + grid_0, grid_1, grid_2, + self.bin.num_warps * 32, 1, 1, args, self.bin.shared_mem) def get_sass(self, fun=None): @@ -632,10 +630,14 @@ class Kernel: @staticmethod def pow2_divisor(N): - if N % 16 == 0: return 16 - if N % 8 == 0: return 8 - if N % 4 == 0: return 4 - if N % 2 == 0: return 2 + if N % 16 == 0: + return 16 + if N % 8 == 0: + return 8 + if N % 4 == 0: + return 4 + if N % 2 == 0: + return 2 return 1 def __init__(self, fn): @@ -675,7 +677,7 @@ class Kernel: 
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \ + attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int) and i not in self.fn.do_not_specialize} # transforms ints whose value is one into constants for just-in-time compilation @@ -705,7 +707,7 @@ class Kernel: if binary is None: binary = self._compile( *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, + num_warps=num_warps, num_stages=num_stages, constants=constants, ) if bin_cache_path: @@ -766,13 +768,12 @@ class Launcher: def __call__(self, *wargs, **kwargs): return self.kernel(*wargs, **kwargs, grid=self.grid) - class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None): + def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: + :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. @@ -788,6 +789,7 @@ class Autotuner: self.hook = lambda args: 0 if reset_to_zero is not None: self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + def _hook(args): for i in self.reset_idx: args[i].zero_() @@ -802,7 +804,7 @@ class Autotuner: perf_model, top_k, prune_num_stages_by = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k self.prune_num_stages_by = prune_num_stages_by - + def _bench(self, *args, config, **meta): # check for conflicts, i.e. 
meta-parameters both provided # as kwargs and by the autotuner @@ -814,6 +816,7 @@ class Autotuner: ) # augment meta-parameters with tunable ones current = dict(meta, **config.kwargs) + def kernel_call(): if config.pre_hook: config.pre_hook(self.nargs) @@ -836,9 +839,9 @@ class Autotuner: top_k = int(len(self.configs) * top_k) if len(pruned_configs) > top_k: est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k] + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) \ + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} bench_end = time.time() self.bench_time = bench_end - bench_start @@ -876,7 +879,7 @@ def version_key(): ptxas_version = '' return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) -#########################3 +# 3 class DependenciesFinder(ast.NodeVisitor): @@ -888,7 +891,7 @@ class DependenciesFinder(ast.NodeVisitor): def visit_Name(self, node): return self.globals.get(node.id, None) - + def visit_Attribute(self, node): lhs = self.visit(node.value) while isinstance(lhs, ast.Attribute): @@ -917,10 +920,10 @@ class DependenciesFinder(ast.NodeVisitor): self.ret = (self.ret + func.hash).encode("utf-8") self.ret = hashlib.md5(self.ret).hexdigest() -class JITFunction: - - cache_hook = None +class JITFunction: + + cache_hook = None def __init__(self, fn, version=None, do_not_specialize=None): # information of wrapped function @@ -946,7 +949,6 @@ class JITFunction: # forward docs self.__doc__ = fn.__doc__ - @property @functools.lru_cache() def cache_key(self): @@ -1027,6 +1029,7 @@ class Config: :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this function are args. """ + def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps @@ -1049,19 +1052,19 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): .. highlight:: python .. code-block:: python - @triton.autotune(configs=[ + @triton.autotune(configs=[ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], + ], key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes + # the value of x_size changes ) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] - + :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. + This means that whatever value the kernel updates will be updated multiple times. To avoid this undesired behavior, you can use the `reset_to_zero` argument, which reset the value of the provided tensor to `zero` before running any configuration. @@ -1069,7 +1072,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): :type configs: list[triton.Config] :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 
:type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: + :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. @@ -1099,7 +1102,7 @@ def heuristics(values): def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - + .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. each such function takes a list of positional arguments as input. .type values: dict[str, Callable[[list[Any]], Any]] @@ -1150,6 +1153,7 @@ def jit(*args, **kwargs): def cdiv(x, y): return (x + y - 1) // y + def next_power_of_2(n): """Return the smallest power of 2 greater than or equal to n""" n -= 1 @@ -1163,13 +1167,14 @@ def next_power_of_2(n): ###### + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype - self.base = base + self.base = base self.is_cuda = base.is_cuda self.device = base.device - + def data_ptr(self): return self.base.data_ptr() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 210a72a30..6895c101c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,8 +1,8 @@ -import triton -from triton._C.libtriton.triton import ir -from triton._C.libtriton.triton import frontend from functools import wraps +import triton +from triton._C.libtriton.triton import frontend, ir + # convert block/dtype to ir values def _to_ir(x, builder): @@ -65,7 +65,7 @@ def builtin(fn): def wrapper(*args, **kwargs): if '_builder' not in kwargs or \ kwargs['_builder'] is None: - raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") + raise ValueError("Did you forget to add @triton.jit ? 
(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) return wrapper @@ -111,6 +111,7 @@ class pointer_dtype: def __str__(self): return f'pointer<{self.element_ty}>' + # scalar types int1 = dtype(ir.type.get_int1) int8 = dtype(ir.type.get_int8) @@ -331,27 +332,27 @@ class constexpr: def __rsub__(self, other): return other.value - self.value - + def __mul__(self, other): return self.value * other.value def __rmul__(self, other): return other.value * self.value - + def __truediv__(self, other): return self.value / other.value - + def __rtruediv__(self, other): return other.value / self.value - + def __floordiv__(self, other): return self.value // other.value - + def __rfloordiv__(self, other): return other.value // self.value # - + def __gt__(self, other): return self.value > other.value @@ -360,25 +361,25 @@ class constexpr: def __ge__(self, other): return self.value >= other.value - + def __rge__(self, other): return other.value >= self.value - + def __lt__(self, other): return self.value < other.value def __rlt__(self, other): return other.value < self.value - + def __le__(self, other): return self.value <= other.value - + def __rle__(self, other): return other.value <= self.value - + def __eq__(self, other): return self.value == other.value - + def __ne__(self, other): return self.value != other.value @@ -489,15 +490,16 @@ def broadcast_to(input, shape, _builder=None): """ return frontend.broadcast_to(input, shape, _builder) + @builtin def cat(input, other, _builder=None): """ Concatenate the given blocks :param input: The first input block. - :type input: + :type input: :param other: The second input block. - :type other: + :type other: """ return frontend.cat(input, other, _builder) @@ -508,7 +510,7 @@ def reshape(input, shape, _builder=None): Tries to reshape the given block to a new shape. :param input: The input block. - :type input: + :type input: :param shape: The desired shape. :type shape: Tuple[int] @@ -546,7 +548,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. - :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. + :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. :code:`other` is implicitly typecast to :code:`pointer.dtype.element_ty`. @@ -565,7 +567,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui @builtin def store(pointer, value, mask=None, _builder=None): """ - Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. + Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. 
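# --- Editor's illustrative sketch (not part of the patch) ---------------------
# A minimal example of the tl.load / tl.store semantics documented in the hunks
# above: `mask` guards out-of-range lanes and `other` supplies the value used for
# masked-off elements. The kernel and tensor names here are hypothetical.
import torch

import triton
import triton.language as tl


@triton.jit
def _masked_copy(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask, other=0.)  # out-of-range lanes read 0.
    tl.store(dst_ptr + offsets, x, mask=mask)  # and are never written back


src = torch.randn(10000, device='cuda')
dst = torch.empty_like(src)
grid = lambda args: (triton.cdiv(src.numel(), args['BLOCK_SIZE']),)
_masked_copy[grid](src, dst, src.numel(), BLOCK_SIZE=1024)
# -----------------------------------------------------------------------------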
@@ -600,9 +602,10 @@ def _add_atomic_docstr(name): """ func.__doc__ = docstr.format(name=name) return func - + return _decorator - + + @builtin @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): @@ -614,6 +617,7 @@ def atomic_cas(pointer, cmp, val, _builder=None): def atomic_xchg(pointer, val, mask=None, _builder=None): return frontend.atomic_xchg(pointer, val, mask, _builder) + @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, _builder=None): @@ -683,6 +687,7 @@ def where(condition, x, y, _builder=None): def umulhi(x, y, _builder=None): return frontend.umulhi(x, y, _builder) + def _add_math_1arg_docstr(name): def _decorator(func): @@ -694,24 +699,28 @@ def _add_math_1arg_docstr(name): """ func.__doc__ = docstr.format(name=name) return func - + return _decorator + @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): return frontend.exp(x, _builder) + @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): return frontend.log(x, _builder) + @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): return frontend.cos(x, _builder) + @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): @@ -739,9 +748,10 @@ def _add_reduction_docstr(name): """ func.__doc__ = docstr.format(name=name) return func - + return _decorator + @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): @@ -759,6 +769,7 @@ def min(input, axis, _builder=None): def sum(input, axis, _builder=None): return frontend.sum(input, axis, _builder) + @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): @@ -778,7 +789,7 @@ def debug_barrier(_builder=None): @builtin def multiple_of(input, value, _builder=None): """ - Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. + Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ return frontend.multiple_of(input, value, _builder) @@ -786,7 +797,7 @@ def multiple_of(input, value, _builder=None): @builtin def max_contiguous(input, value, _builder=None): """ - Let the compiler knows that the `value` first values in :code:`input` are contiguous. + Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ return frontend.max_contiguous(input, value, _builder) @@ -794,7 +805,7 @@ def max_contiguous(input, value, _builder=None): @builtin def max_contiguous(input, value, _builder=None): """ - Let the compiler knows that the `value` first values in :code:`input` are contiguous. + Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ return frontend.max_contiguous(input, value, _builder) @@ -807,6 +818,7 @@ def max_contiguous(input, value, _builder=None): def abs(x): return where(x >= 0, x, -x) + @triton.jit def cdiv(x, div): """ @@ -871,13 +883,14 @@ def ravel(x): """ return triton.language.reshape(x, [x.type.numel]) + @triton.jit def swizzle2d(i, j, size_i, size_j, size_g): """ transformes indices of a row-major size_i*size_j matrix into those of one where indices are row major for each group of size_j rows. 
For example, for size_i = size_j = 4 and size_g = 2, it will transform - [[0 , 1 , 2 , 3 ], + [[0 , 1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 , 9 , 10, 11], [12, 13, 14, 15]] @@ -888,16 +901,16 @@ def swizzle2d(i, j, size_i, size_j, size_g): [9, 11, 13, 15]] """ # "unrolled index in array" - ij = i*size_j + j + ij = i * size_j + j # number of elements in `size_g` groups # of `size_j` columns size_gj = size_g * size_j # index of the group in which (i,j) is - group_id = ij // size_gj + group_id = ij // size_gj # row-index of the first element of this group off_i = group_id * size_g # last group may have fewer rows - size_g = minimum(size_i - off_i, size_g) + size_g = minimum(size_i - off_i, size_g) # new row and column indices new_i = off_i + (ij % size_g) new_j = (ij % size_gj) // size_g diff --git a/python/triton/language/random.py b/python/triton/language/random.py index cb2ddfc6b..6f3645b41 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -1,7 +1,6 @@ import triton from . import core as tl - PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py index ca6ca61f8..7d27ffd20 100644 --- a/python/triton/ops/__init__.py +++ b/python/triton/ops/__init__.py @@ -1,4 +1,4 @@ #from .conv import _conv, conv -from .matmul import _matmul, matmul +from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy -from . import blocksparse \ No newline at end of file +from .matmul import _matmul, matmul diff --git a/python/triton/ops/blocksparse/__init__.py b/python/triton/ops/blocksparse/__init__.py index c8da856aa..231c27a1f 100644 --- a/python/triton/ops/blocksparse/__init__.py +++ b/python/triton/ops/blocksparse/__init__.py @@ -1,2 +1,2 @@ from .matmul import matmul -from .softmax import softmax \ No newline at end of file +from .softmax import softmax diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index ce15c9af4..15e6c0523 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -1,6 +1,7 @@ +import torch + import triton import triton.language as tl -import torch # ******************************************************** # -------------------------------------------------------- @@ -11,16 +12,17 @@ import torch # -------------------------------------------------------- # ******************************************************** + @triton.heuristics({ 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, }) @triton.jit def _sdd_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ak, - stride_zb, stride_hb, stride_bk, stride_nb, - stride_zc, stride_hc, stride_mc, stride_nc, - K, grid_offset, lut, + A, B, C, + stride_za, stride_ha, stride_ma, stride_ak, + stride_zb, stride_hb, stride_bk, stride_nb, + stride_zc, stride_hc, stride_mc, stride_nc, + K, grid_offset, lut, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, BLOCK: tl.constexpr, EVEN_K: tl.constexpr ): @@ -30,25 +32,25 @@ def _sdd_kernel( block_id = tl.program_id(1) + grid_offset lut += block_id * 3 # offsets - off_z = tl.program_id(2) # batch - off_h = tl.load(lut + 0) # head - + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + # initialize pointers to A start_am = tl.load(lut + 1) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_ak = tl.arange(0, TILE_K) - 
a_ptrs = A + (off_z * stride_za \ - + off_h * stride_ha \ - + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ak) + a_ptrs = A + (off_z * stride_za + + off_h * stride_ha + + offs_am[:, None] * stride_ma + + offs_ak[None, :] * stride_ak) # initialize pointers to B start_bn = tl.load(lut + 2) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bk = tl.arange(0, TILE_K) - b_ptrs = B + (off_z * stride_zb \ - + off_h * stride_hb \ - + offs_bn[None, :] * stride_nb \ - + offs_bk[:, None] * stride_bk) + b_ptrs = B + (off_z * stride_zb + + off_h * stride_hb + + offs_bn[None, :] * stride_nb + + offs_bk[:, None] * stride_bk) ## ---------------- ## ## Inner Loop ## ## ---------------- ## @@ -69,13 +71,14 @@ def _sdd_kernel( ## ---------------- ## offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK - pc = C + (off_z * stride_zc \ - + block_id * stride_hc \ - + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc) + pc = C + (off_z * stride_zc + + block_id * stride_hc + + offs_cm[:, None] * stride_mc + + offs_cn[None, :] * stride_nc) tl.store(pc, c, mask=True) -def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None): + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): if a.stride(2) != 1 and a.stride(3) != 1: a = a.contiguous() if b.stride(2) != 1 and b.stride(3) != 1: @@ -103,7 +106,7 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(2), c.stride(3), Ka, 0, lut, - TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4, + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, num_warps=4, ) return c @@ -119,50 +122,52 @@ def sdd_lut(layout, block, device): # This operation uses a look-up table that contains pre-computed pointer increments # in order to minimize computations in the inner loop of the matmul kernel. 
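# --- Editor's illustrative sketch (not part of the patch) ---------------------
# A simplified picture of the look-up-table idea described in the comment above;
# this is NOT the exact encoding produced by dsd_lut(), only the intuition: for
# every row of output blocks, record the nonzero K-blocks of the sparse operand
# and the increment between consecutive ones, so the kernel's inner loop only
# advances its pointers by a precomputed delta instead of re-deriving addresses.
import torch


def toy_dsd_lut(layout, block):
    # layout: (H, M // block, K // block) tensor of 0/1 block flags
    entries = []
    for h in range(layout.shape[0]):
        for m in range(layout.shape[1]):
            ks = layout[h, m].nonzero(as_tuple=True)[0]
            if ks.numel() == 0:
                continue
            # offset of the first nonzero block, then deltas between consecutive blocks
            incs = torch.diff(ks, prepend=ks.new_zeros(1)) * block
            entries.append((h, m, incs))
    return entries
# -----------------------------------------------------------------------------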
# ----------------------------- + + @triton.jit def _dsd_kernel( - A, B, C, - stride_az, stride_ha, stride_am, stride_ak, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_cm, stride_cn, - DS0, DS1, lut, + A, B, C, + stride_az, stride_ha, stride_am, stride_ak, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_cm, stride_cn, + DS0, DS1, lut, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr ): #------------# #- Prologue -# #------------# - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) - pidz = tl.program_id(2) + pidz = tl.program_id(2) header = lut + pid_n * 4 offset = tl.load(header + 0) - K = tl.load(header + 1) + K = tl.load(header + 1) column = tl.load(header + 2) - off_h = tl.load(header + 3) - pinc = lut + offset + off_h = tl.load(header + 3) + pinc = lut + offset # initialize pointers to A (sparse) - block_id = tl.load(pinc + 1) - block_id = tl.multiple_of(block_id, 8) # compiler hint + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint offs_am = tl.arange(0, TILE_M) offs_ak = tl.arange(0, TILE_K) pa = A + pidz * stride_az \ - + block_id * stride_ha \ - + offs_am[:, None] * stride_am \ - + offs_ak[None, :] * stride_ak + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak # initialize pointers to B (dense) - offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) start_bk = tl.load(pinc) start_bk = tl.multiple_of(start_bk, 8) # compiler hint - offs_bk = start_bk + tl.arange(0, TILE_K) + offs_bk = start_bk + tl.arange(0, TILE_K) pb = B + pidz * stride_zb \ - + off_h * stride_hb \ - + offs_bn[None, :] * stride_bn \ - + offs_bk[:, None] * stride_bk + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk ## ---------------- ## ## Inner Loop ## ## ---------------- ## @@ -177,7 +182,7 @@ def _dsd_kernel( b = tl.load(pb, mask=offs_bn[None, :] < DS0) acc += tl.dot(a, b) pa += inc_a - pb += inc_b*stride_bk + pb += inc_b * stride_bk pinc += 2 inc_a = tl.load(pinc + 1) inc_a = tl.multiple_of(inc_a, 8) @@ -185,15 +190,16 @@ def _dsd_kernel( inc_b = tl.multiple_of(inc_b, 8) c = acc.to(C.dtype.element_ty) # initialize pointers to C - offs_cm = column*TILE_M + tl.arange(0, TILE_M) - offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N) + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) pc = C + off_h * stride_hc \ - + pidz * stride_zc \ - + offs_cm[:, None] * stride_cm \ - + offs_cn[None, :] * stride_cn - tl.store(pc, c, mask = offs_cn[None, :] < DS0) + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) -def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): if a.stride(2) != 1 and a.stride(3) != 1: a = a.contiguous() if b.stride(2) != 1 and b.stride(3) != 1: @@ -231,7 +237,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = # exit() return c -def dsd_lut(layout, block, step, trans, 
device): +def dsd_lut(layout, block, step, trans, device): sizes = torch.sum(layout, 2 if trans else 1) head_id, col_id = sizes.nonzero(as_tuple=True) sizes = sizes.flatten() @@ -313,11 +319,11 @@ def dsd_lut(layout, block, step, trans, device): # ----------------------------- @triton.jit def _dds_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ka, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_mc, stride_nc, - DS0, DS1, lut, + A, B, C, + stride_za, stride_ha, stride_ma, stride_ka, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_mc, stride_nc, + DS0, DS1, lut, TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, ): @@ -348,7 +354,7 @@ def _dds_kernel( + offs_ak[None, :] * stride_ka # initialize pointers to B (sparse) block_id = tl.load(pinc + 1) - block_id = tl.multiple_of(block_id, 8) + block_id = tl.multiple_of(block_id, 8) offs_bn = tl.arange(0, TILE_N) offs_bk = tl.arange(0, TILE_K) ptrs_b = B + pid_z * stride_zb \ @@ -429,7 +435,7 @@ class _matmul(torch.autograd.Function): @staticmethod def forward( - ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, + ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, db_width, out ): c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) @@ -499,10 +505,10 @@ class matmul: def __call__(self, a, b, out = None): c = _matmul.apply( - a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, - self.c_lut, self.c_width, - self.da_lut, self.da_width, + a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, + self.c_lut, self.c_width, + self.da_lut, self.da_width, self.db_lut, self.db_width, out ) - return c \ No newline at end of file + return c diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index dcf77afc8..f9d49ae56 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -1,7 +1,8 @@ -import triton.language as tl -import triton import torch +import triton +import triton.language as tl + def num_warps(n): if n < 512: @@ -33,10 +34,10 @@ def _forward( check = rbn < size rbmn = tl.where(check, rbn, size - 1) # block id and column id - blockid = tl.load(LUT + offset + rbmn * 4 + 0) + blockid = tl.load(LUT + offset + rbmn * 4 + 0) columnid = tl.load(LUT + offset + rbmn * 4 + 1) - rowid = tl.load(LUT + offset + rbmn * 4 + 2) - headid = tl.load(LUT + offset + rbmn * 4 + 3) + rowid = tl.load(LUT + offset + rbmn * 4 + 2) + headid = tl.load(LUT + offset + rbmn * 4 + 3) # pointers to X px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn x = tl.load(px, mask=check, other=-float('inf')) @@ -64,7 +65,7 @@ def _forward( attn_m = tl.where(attn_m == 0, -float('inf'), 0.) x = x + attn_m # apply causal mask - is_in_upper_triangle = columnid*BLOCK + rxn > rowid*BLOCK + rxm + is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) 
# computation x = tl.softmax(x) @@ -127,9 +128,9 @@ class _softmax(torch.autograd.Function): @staticmethod def forward( - ctx, x, scale, rpe, - key_padding_mask, attn_mask, - kp_mask_mode, attn_mask_mode, + ctx, x, scale, rpe, + key_padding_mask, attn_mask, + kp_mask_mode, attn_mask_mode, is_causal, spdims, block, lut, maxlut ): @@ -161,15 +162,15 @@ class _softmax(torch.autograd.Function): # run kernel M = x.shape[0] grid = [spdims[0] * spdims[1] * block, M] - _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),\ - stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - BLOCK = block, - APPLY_SCALE = apply_scale, - APPLY_RPE = apply_rpe, - APPLY_KP_MASK = apply_kp_mask, - APPLY_ATTN_MASK = apply_attn_mask, - KP_MASK_MUL = (kp_mask_mode == 'mul'), - ATTN_MASK_MUL = (attn_mask_mode == 'mul')) + _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0), + stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, + BLOCK=block, + APPLY_SCALE=apply_scale, + APPLY_RPE=apply_rpe, + APPLY_KP_MASK=apply_kp_mask, + APPLY_ATTN_MASK=apply_attn_mask, + KP_MASK_MUL=(kp_mask_mode == 'mul'), + ATTN_MASK_MUL=(attn_mask_mode == 'mul')) # save to context ctx.mark_dirty(x) ctx.save_for_backward(x, lut) @@ -211,10 +212,10 @@ class softmax: self.lut_cache = dict() def __call__( - self, x, scale=1., rpe=None, - key_padding_mask=None, attn_mask=None, + self, x, scale=1., rpe=None, + key_padding_mask=None, attn_mask=None, key_padding_mask_mode='add', attn_mask_mode='add', - is_causal = False + is_causal=False ): if rpe is not None and rpe.dtype != x.dtype: raise ValueError('relative position embedding must be %s' % x.dtype) @@ -224,11 +225,11 @@ class softmax: raise ValueError('Key padding mask must be %s' % x.dtype) lut, maxlut = self.make_lut(x.device) x = _softmax.apply( - x, scale, rpe, - key_padding_mask, attn_mask, - key_padding_mask_mode, attn_mask_mode, + x, scale, rpe, + key_padding_mask, attn_mask, + key_padding_mask_mode, attn_mask_mode, is_causal, - self.spdims, self.block, + self.spdims, self.block, lut, maxlut ) - return x \ No newline at end of file + return x diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 529b6c675..dfd4f4487 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -1,7 +1,9 @@ import os + +import torch + import triton import triton.language as tl -import torch def next_power_of_2(n): @@ -104,4 +106,4 @@ class _cross_entropy(torch.autograd.Function): return neg_logprobs, None -cross_entropy = _cross_entropy.apply \ No newline at end of file +cross_entropy = _cross_entropy.apply diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 8b7299a8b..60ecc9f3b 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -1,11 +1,14 @@ import torch -import triton.language as tl + import triton +import triton.language as tl from .matmul_perf_model import * + def init_to_zero(name): return lambda nargs: nargs[name].zero_() + def get_configs_io_bound(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -14,14 +17,15 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + 
num_stages=num_stages, num_warps=num_warps)) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) return configs + @triton.heuristics({ 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, }) @@ -30,26 +34,26 @@ def get_configs_io_bound(): # basic configs for compute-bound matmuls triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), ] + get_configs_io_bound(), key=['M', 'N', 'K'], prune_configs_by={ - 'prune_num_stages_by' : prune_num_stages, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + 'prune_num_stages_by': prune_num_stages, + 'perf_model': estimate_matmul_time, + 'top_k': 10 }, ) @triton.jit -def _kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, +def _kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): # matrix multiplication @@ -68,12 +72,12 @@ def _kernel(A, B, C, M, N, K, rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) # pointers A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(K, 0, -BLOCK_K*SPLIT_K): + for k in range(K, 0, -BLOCK_K * SPLIT_K): if EVEN_K: a = tl.load(A) b = tl.load(B) @@ -117,10 
+121,10 @@ class _matmul(torch.autograd.Function): c = torch.empty((M, N), device=device, dtype=a.dtype) # launch kernel grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), + _kernel[grid](a, b, c, M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), GROUP_M=8) return c diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index 16667a7b1..af4f3eed8 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -1,116 +1,121 @@ +import heapq + import torch + import triton import triton._C.libtriton.triton as _triton from triton.testing import get_dram_gbps, get_max_tensorcore_tflops -import heapq + def get_tensorcore_tflops(backend, device, num_ctas, num_warps): - ''' return compute throughput in TOPS ''' - total_warps = num_ctas * min(num_warps, 4) - num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs - tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device) - return tflops + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device) + return tflops + def estimate_matmul_time( - # backend, device, - num_warps, num_stages, - M, N, K, - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, - debug=False, **kwargs + # backend, device, + num_warps, num_stages, + M, N, K, + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, + debug=False, **kwargs ): - ''' return estimated running time in ms - = max(compute, loading) + store ''' - backend = _triton.runtime.backend.CUDA - device = torch.cuda.current_device() + ''' return estimated running time in ms + = max(compute, loading) + store ''' + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() - num_cta_m = triton.cdiv(M, BLOCK_M) - num_cta_n = triton.cdiv(N, BLOCK_N) - num_cta_k = SPLIT_K - num_ctas = num_cta_m * num_cta_n * num_cta_k + num_cta_m = triton.cdiv(M, BLOCK_M) + num_cta_n = triton.cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k - # If the input is smaller than the block size - M, N = max(M, BLOCK_M), max(N, BLOCK_N) + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) - # time to compute - total_ops = 2*M*N*K / (1024*1024*1024) # GOPS - tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) - compute_ms = total_ops / tput + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) + compute_ms = total_ops / tput - # time to load data - num_sm = _triton.runtime.num_sm(backend, device) - active_cta_ratio = min(1, num_ctas/num_sm) - active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate - active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5% - dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s - l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
- # assume 80% of (following) loads are in L2 cache - load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2) - load_a_l2 = M*K*2*0.8*(num_cta_n-1) - load_b_dram = N*K*2*(1+0.2*(num_cta_m-1)) - load_b_l2 = N*K*2*0.8*(num_cta_m-1) - # total - total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB - total_l2 = (load_a_l2 + load_b_l2) / (1024*1024) - # loading time in ms - load_ms = total_dram/dram_bw + total_l2/l2_bw + # time to load data + num_sm = _triton.runtime.num_sm(backend, device) + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2) + load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw - # estimate storing time - store_bw = dram_bw * 0.6 # :o - store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB - if SPLIT_K == 1: - store_ms = store_c_dram /store_bw - else: - reduce_bw = store_bw - store_ms = store_c_dram/reduce_bw - # c.zero_() - zero_ms = M*N*2/(1024*1024)/store_bw - store_ms += zero_ms + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms - total_time_ms = max(compute_ms, load_ms) + store_ms - if debug: - print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' - f'loading time: {load_ms}ms, store time: {store_ms}ms, ' - f'Activate CTAs: {active_cta_ratio*100}%') - return total_time_ms def prune_num_stages(configs): - backend = _triton.runtime.backend.CUDA - device = torch.cuda.current_device() - cc = _triton.runtime.cc(backend, device) - # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() + cc = _triton.runtime.cc(backend, device) + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages - # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) - configs_map = {} - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ - kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages - - key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) - if key in configs_map: - configs_map[key].append((config, num_stages)) - else: - configs_map[key] = [(config, num_stages)] + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: 
+ kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages - pruned_configs = [] - for k, v in configs_map.items(): - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k - if cc >= 80: - # compute cycles (only works for ampere GPUs) - mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16) - mma_cycles = mmas/min(4, num_warps) * 8 + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] - ldgsts_latency = 300 # Does this matter? - optimal_num_stages = ldgsts_latency/mma_cycles + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if cc >= 80: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 - # nearest stages, prefer large #stages - nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \ - if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles - for n in nearest: - pruned_configs.append(n[0]) - else: # Volta & Turing only supports num_stages <= 2 - random_config = v[0][0] - random_config.num_stages = 2 - pruned_configs.append(random_config) - return pruned_configs \ No newline at end of file + # nearest stages, prefer large #stages + nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/python/triton/testing.py b/python/triton/testing.py index eef7f5be6..310e754ed 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,10 +1,11 @@ -import torch import os -import triton._C.libtriton.triton as _triton -from .code_gen import OutOfResources import subprocess import sys +import torch + +import triton._C.libtriton.triton as _triton +from .code_gen import OutOfResources try: import triton._C.libtriton.cutlass as _cutlass @@ -13,6 +14,7 @@ except ImportError: _cutlass = None has_cutlass = False + def catch_oor(kernel, pytest_handle=None): try: res = kernel() @@ -42,11 +44,11 @@ def cutlass_matmul(a, b): c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device) # run function dtype = str(a.dtype).split('.')[-1] - _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), \ - M, N, Ka,\ - a.stride(0), a.stride(1),\ - b.stride(0), b.stride(1),\ - c.stride(0), c.stride(1),\ + _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(), + M, N, Ka, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), dtype, dtype, dtype, a.device.index, torch.cuda.current_stream(a.device).cuda_stream) @@ -59,6 +61,7 @@ def mask_tensor(x, mask, block, value=0): ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value return ret + def assert_almost_equal(x, y, decimal=2, err_msg=''): import numpy.testing as npt if isinstance(x, torch.Tensor): @@ -93,6 +96,7 @@ def nvsmi(attrs): ret = [int(x) for x in ret] return ret + def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], 
record_clocks=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with @@ -122,13 +126,13 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0 torch.cuda.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 # compute number of warmup and repeat - n_warmup = max(1, int(warmup/estimate_ms)) - n_repeat = max(1, int(rep/estimate_ms)) + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) # We maintain a buffer of 256 MB that we clear # before each kernel call to make sure that the L2 # doesn't contain any input data before the run start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') # Warm-up for _ in range(n_warmup): @@ -161,6 +165,7 @@ class Benchmark: """ This class is used by the :code:`perf_report` function to generate line plots with a concise API. """ + def __init__( self, x_names, @@ -224,9 +229,10 @@ class Mark: self.benchmarks = benchmarks def _run(self, bench, save_path, show_plots, print_data): + import os + import matplotlib.pyplot as plt import pandas as pd - import os y_mean = bench.line_names y_min = [f'{x}-min' for x in bench.line_names] y_max = [f'{x}-max' for x in bench.line_names] @@ -259,7 +265,7 @@ class Mark: xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names) ax.set_xlabel(xlabel) ax.set_ylabel(bench.ylabel) - #ax.set_title(bench.plot_name) + # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") ax.set_yscale("log" if bench.y_log else "linear") if show_plots: @@ -297,6 +303,7 @@ def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) return wrapper + def get_dram_gbps(backend=None, device=None): ''' return DRAM bandwidth in GB/s ''' # assert backend == CUDA @@ -306,17 +313,18 @@ def get_dram_gbps(backend=None, device=None): device = torch.cuda.current_device() mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) bus_width = _triton.runtime.global_memory_bus_width(backend, device) - bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s + bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s return bw_gbps + def get_max_tensorcore_tflops(backend, device): - num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs - clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz # assume fp32 += fp16*fp16 cc = _triton.runtime.cc(backend, device) if cc < 80: - ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores else: ops_per_sub_core = 512 - tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024) + tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024) return tflops diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index fbbfa6d0b..3b443c690 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -21,8 +21,8 @@ # SOFTWARE. 
import argparse -import subprocess import re +import subprocess FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 0934c8ea1..c78ccabbc 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -12,8 +12,8 @@ In this tutorial, you will write a simple vector addition using Triton and learn # Compute Kernel # -------------------------- -from triton.language.core import constexpr import torch + import triton import triton.language as tl @@ -38,7 +38,7 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) # Create a mask to guard memory operations against out-of-bounds accesses mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case + # Load x and y from DRAM, masking out any extra elements in case # the input is not a multiple of the block size x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 2c0cfb9a8..30e507b0d 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -16,6 +16,8 @@ You will learn about: # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Let us consider instead the case of a simple (numerically stabilized) softmax operation: +import triton.language as tl +import triton import torch @@ -59,13 +61,10 @@ def naive_softmax(x): # power-of-two number of elements, so we need to internally "pad" each row and guard the # memory operations properly if we want to handle any possible input shapes: -import triton -import triton.language as tl - @triton.jit def softmax_kernel( - output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, + output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr ): # The rows of the softmax are independent, so we parallelize across those @@ -136,7 +135,7 @@ y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) print(torch.allclose(y_triton, y_torch)) -#%% +# %% # As expected, the results are identical. # %% @@ -187,5 +186,5 @@ benchmark.run(show_plots=True, print_data=True) # In the above plot, we can see that: # # - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. -# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. +# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. # Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape. 
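A recurring change across this series is that kernel meta-parameters stop travelling through a `**meta` dict and become `tl.constexpr` arguments passed by keyword at launch time (the dropout and layer-norm kernels further down are converted the same way). As a rough sketch of the call-site pattern for the `softmax_kernel` signature shown above, the tutorial's own wrapper, which is not reproduced in this hunk, does essentially the following and also picks `num_warps` based on the block size:

    import torch

    def softmax_2d(x: torch.Tensor) -> torch.Tensor:
        # assumes the `softmax_kernel` defined in this tutorial is in scope
        n_rows, n_cols = x.shape
        # one program per row, padded to the next power of two
        BLOCK_SIZE = 1 << (n_cols - 1).bit_length()
        y = torch.empty_like(x)
        softmax_kernel[(n_rows,)](
            y, x, x.stride(0), y.stride(0), n_cols,
            BLOCK_SIZE=BLOCK_SIZE,  # tl.constexpr meta-parameter, passed by keyword
            num_warps=4,            # launch option, as in the layer-norm launches below
        )
        return y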
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 2d2ab91e9..240583df2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -112,13 +112,13 @@ You will specifically learn about: # # number of programs ids along the N axis # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # # number of programs in group -# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# num_pid_in_group = GROUP_SIZE_M * num_pid_n # # id of the group this program is in -# group_id = pid // num_pid_in_group +# group_id = pid // num_pid_in_group # # row-id of the first program in the group -# first_pid_m = group_id * GROUP_SIZE_M +# first_pid_m = group_id * GROUP_SIZE_M # # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller -# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # # *within groups*, programs are ordered in a column-major order # # row-id of the program in the *launch grid* # pid_m = first_pid_m + (pid % group_size_m) @@ -141,6 +141,7 @@ You will specifically learn about: # import torch + import triton import triton.language as tl @@ -152,18 +153,19 @@ import triton.language as tl # - An autotuning *key* whose change in values will trigger evaluation of all the # provided configs + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64 , 'BLOCK_SIZE_N': 32 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32 , 'BLOCK_SIZE_N': 64 , 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), ], 
key=['M', 'N', 'K'], ) @@ -185,7 +187,7 @@ def matmul_kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr, - ): +): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ @@ -196,16 +198,16 @@ def matmul_kernel( pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction + # We will advance this pointer as we move in the K direction # and accumulate # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers @@ -213,8 +215,8 @@ def matmul_kernel( offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) - b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix @@ -223,8 +225,8 @@ def matmul_kernel( # `accumulator` will be converted back to fp16 after the loop accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): - # Note that for simplicity, we don't apply a mask here. - # This means that if K is not a multiple of BLOCK_SIZE_K, + # Note that for simplicity, we don't apply a mask here. + # This means that if K is not a multiple of BLOCK_SIZE_K, # this will access out-of-bounds memory and produce an # error or (worse!) incorrect results. a = tl.load(a_ptrs) @@ -236,7 +238,7 @@ def matmul_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk # you can fuse arbitrary activation functions here # while the accumulator is still in FP32 ! - if meta['ACTIVATION']: + if meta['ACTIVATION']: accumulator = meta['ACTIVATION'](accumulator) c = accumulator.to(tl.float16) diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index d988746a7..5c4f53435 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -13,7 +13,7 @@ whose state is generally composed of a bit mask tensor of the same shape as the # %% # Baseline # ------------- -# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance +# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance # of deep neural networks in low-data regime (i.e. regularization). # # It takes a vector as input and produces a vector of the same shape as output. 
Each scalar in the @@ -30,16 +30,18 @@ whose state is generally composed of a bit mask tensor of the same shape as the import tabulate import torch + import triton import triton.language as tl + @triton.jit def _dropout( - x_ptr, # pointer to the input - x_keep_ptr, # pointer to a mask of 0s and 1s - output_ptr, # pointer to the output - n_elements, # number of elements in the `x` tensor - p, # probability that an element of `x` is changed to zero + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero **meta, ): BLOCK_SIZE = meta['BLOCK_SIZE'] @@ -64,6 +66,7 @@ def dropout(x, x_keep, p): _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) return output + # Input tensor x = torch.randn(size=(10,)).cuda() # Dropout mask @@ -88,7 +91,7 @@ print(tabulate.tabulate([ # of persisting randomness across multiple invocations of the kernel. # # Pseudorandom number generation in Triton is simple! In this tutorial we will use the -# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` +# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` # values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides # other :ref:`random number generation strategies `. # @@ -97,6 +100,7 @@ print(tabulate.tabulate([ # # Let's put it all together. + @triton.jit def _seeded_dropout( x_ptr, diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index ffad17f50..82231e15c 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -4,15 +4,17 @@ Layer Normalization """ import torch -import triton.language as tl + import triton +import triton.language as tl + # Forward Pass @triton.jit def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): BLOCK_SIZE = META['BLOCK_SIZE'] # position of elements processed by this program - row = tl.program_id(0) + row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE) mask = cols < N # offset data pointers to start at the row of interest @@ -24,9 +26,9 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): mean = tl.sum(x, axis=0) / N # compute std xmean = tl.where(mask, x - mean, 0.) 
- var = tl.sum(xmean * xmean, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - xhat = xmean*rstd + var = tl.sum(xmean * xmean, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + xhat = xmean * rstd # write-back mean/rstd tl.store(M + row, mean) tl.store(V + row, rstd) @@ -41,16 +43,16 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): # Backward pass (DX + partial DW + partial DB) @triton.jit def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, - stride, N, eps, - **META): + stride, N, eps, + **META): GROUP_SIZE_M = META['GROUP_SIZE_M'] BLOCK_SIZE_N = META['BLOCK_SIZE_N'] # position of elements processed by this program - row = tl.program_id(0) + row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) mask = cols < N # offset data pointers to start at the row of interest - X += row * stride + X += row * stride DY += row * stride DX += row * stride # offset locks and weight/bias gradient pointer @@ -59,28 +61,28 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, # these buffers stay in the L2, which allow this kernel # to be fast lock_id = row % GROUP_SIZE_M - Lock += lock_id - Count = Lock + GROUP_SIZE_M - DW = DW + lock_id*N + cols - DB = DB + lock_id*N + cols + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + DB = DB + lock_id * N + cols # load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - w = tl.load(W + cols, mask=mask).to(tl.float32) - mean = tl.load(M + row) - rstd = tl.load(V + row) + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + mean = tl.load(M + row) + rstd = tl.load(V + row) # compute dx - xhat = (x - mean)*rstd - wdy = w * dy - xhat = tl.where(mask, xhat, 0.) - wdy = tl.where(mask, wdy , 0.) + xhat = (x - mean) * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.) + wdy = tl.where(mask, wdy, 0.) mean1 = tl.sum(xhat * wdy, axis=0) / N mean2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat*mean1 + mean2))*rstd + dx = (wdy - (xhat * mean1 + mean2)) * rstd # write-back dx tl.store(DX + cols, dx, mask=mask) # accumulate partial sums for dw/db - partial_dw = (dy*xhat).to(w.dtype) + partial_dw = (dy * xhat).to(w.dtype) partial_db = (dy).to(w.dtype) while tl.atomic_cas(Lock, 0, 1) == 1: pass @@ -97,24 +99,27 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, tl.atomic_xchg(Lock, 0) # Backward pass (total DW + total DB) + + @triton.jit def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): pid = tl.program_id(0) BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] - cols = pid*BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None]*N + cols[None, :] + offs = rows[:, None] * N + cols[None, :] dw += tl.load(DW + offs, mask=mask, other=0.) db += tl.load(DB + offs, mask=mask, other=0.) 
sum_dw = tl.sum(dw, axis=0) sum_db = tl.sum(db, axis=0) - tl.store(FINAL_DW + cols, sum_dw, mask=cols BLOCK_SIZE: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, - x_arg.stride(0), N, eps, + _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, + x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE - ctx.num_warps = num_warps - ctx.eps = eps + ctx.num_warps = num_warps + ctx.eps = eps return y @staticmethod @@ -154,11 +159,11 @@ class LayerNorm(torch.autograd.Function): if N <= 4096: GROUP_SIZE_M = 128 if N <= 1024: GROUP_SIZE_M = 256 # allocate output - locks = torch.zeros(2*GROUP_SIZE_M, dtype=torch.int32, device='cuda') + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) - dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) - db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) + dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) + db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) dx = torch.empty_like(dy) # enqueue kernel using forward pass heuristics # also compute partial sums for DW and DB @@ -166,14 +171,14 @@ class LayerNorm(torch.autograd.Function): M, N = x_arg.shape _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, x_arg.stride(0), N, ctx.eps, - BLOCK_SIZE_N=ctx.BLOCK_SIZE, + BLOCK_SIZE_N=ctx.BLOCK_SIZE, GROUP_SIZE_M=GROUP_SIZE_M, num_warps=ctx.num_warps) grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] # accumulate partial sums in separate kernel - _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, - BLOCK_SIZE_M = 32, - BLOCK_SIZE_N = 128) + _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, + BLOCK_SIZE_M=32, + BLOCK_SIZE_N=128) return dx, None, dw, db, None @@ -184,10 +189,10 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') - dy = .1*torch.randn_like(x) + weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + dy = .1 * torch.randn_like(x) x.requires_grad_(True) # forward pass y_tri = layer_norm(x, w_shape, weight, bias, eps) @@ -205,6 +210,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1) triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1) + @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'], @@ -218,14 +224,14 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} ) ) -def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cuda'): +def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = 
(x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) - x = -2.3 + 0.5*torch.randn(x_shape, dtype=dtype, device='cuda') - dy = .1*torch.randn_like(x) + weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + dy = .1 * torch.randn_like(x) x.requires_grad_(True) # utility functions if provider == 'triton': @@ -238,14 +244,15 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward',eps=1e-5, device='cu y_fwd = lambda: apex_layer_norm(x) # forward pass if mode == 'forward': - gbps = lambda ms: 2*x.numel()*x.element_size()/ms*1e-6 + gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500) # backward pass if mode == 'backward': - gbps = lambda ms: 3*x.numel()*x.element_size()/ms*1e-6 + gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6 y = y_fwd() - ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) + bench_layer_norm.run(save_path='.', print_data=True) From 9801aa7b567012610b7467584483bec340d7b429 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 7 Jan 2022 12:34:38 -0800 Subject: [PATCH 042/215] [DOCS] fix tutorials for v2.0 (#422) - Fix meta-parameter usage on tutorials. - Install tutorial dependencies on CI. - Switch from `requirements-test.txt` to `extras_require` for test dependencies, and also use it for tutorial dependencies. - Make some performance tests deterministic. --- .github/workflows/documentation.yml | 4 +-- .github/workflows/integration-tests.yml | 3 +- docs/getting-started/installation.rst | 2 +- python/requirements-test.txt | 3 -- python/setup.py | 18 +++++++++++- python/test/regression/test_performance.py | 2 ++ python/tutorials/02-fused-softmax.py | 2 +- python/tutorials/03-matrix-multiplication.py | 6 ++-- python/tutorials/04-low-memory-dropout.py | 6 ++-- python/tutorials/05-layer-norm.py | 31 +++++++++++--------- python/tutorials/README.rst | 9 +++++- 11 files changed, 54 insertions(+), 32 deletions(-) delete mode 100644 python/requirements-test.txt diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 695dfd1e3..d4ba42733 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -25,7 +25,7 @@ jobs: run: | alias python='python3' cd python - pip3 install -e . + pip3 install -e '.[tutorials]' - name: Build docs run: | @@ -39,4 +39,4 @@ jobs: eval `ssh-agent -s` DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }} git remote set-url origin git@github.com:openai/triton.git - git push \ No newline at end of file + git push diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 987b346a3..d99e95dc7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -28,7 +28,7 @@ jobs: run: | alias python='python3' cd python - pip3 install -e . + pip3 install -e '.[tests]' - name: Unit tests run: | @@ -44,4 +44,3 @@ jobs: pytest -vs . 
sudo nvidia-smi -i 0 -rgc sudo nvidia-smi -i 0 -rmc - diff --git a/docs/getting-started/installation.rst b/docs/getting-started/installation.rst index db6b6261b..20c4628bc 100644 --- a/docs/getting-started/installation.rst +++ b/docs/getting-started/installation.rst @@ -44,7 +44,7 @@ You can then test your installation by running the unit tests: .. code-block:: bash - pip install -r requirements-test.txt + pip install -e '.[tests]' pytest -vs test/unit/ and the benchmarks diff --git a/python/requirements-test.txt b/python/requirements-test.txt deleted file mode 100644 index 84893a889..000000000 --- a/python/requirements-test.txt +++ /dev/null @@ -1,3 +0,0 @@ -numpy -pytest -scipy >= 1.7.1 diff --git a/python/setup.py b/python/setup.py index 28194f41e..1171ad0a8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -126,7 +126,11 @@ setup( description="A language and compiler for custom Deep Learning operations", long_description="", packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"], - install_requires=["cmake", "torch", "filelock"], + install_requires=[ + "cmake", + "filelock", + "torch", + ], package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], @@ -142,4 +146,16 @@ setup( "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.6", ], + extras_require={ + "tests": [ + "numpy", + "pytest", + "scipy>=1.7.1", + ], + "tutorials": [ + "matplotlib", + "pandas", + "tabulate", + ], + }, ) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 012ff65d7..f6e7ec237 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -54,6 +54,7 @@ matmul_data = { @pytest.mark.parametrize('M, N, K', matmul_data.keys()) def test_matmul(M, N, K): + torch.manual_seed(0) ref_gpu_util = matmul_data[(M, N, K)]['v100'] cur_sm_clock = nvsmi(['clocks.current.sm'])[0] ref_sm_clock = 1350 @@ -99,6 +100,7 @@ elementwise_data = { @pytest.mark.parametrize('N', elementwise_data.keys()) def test_elementwise(N): + torch.manual_seed(0) ref_gpu_util = elementwise_data[N]['v100'] cur_mem_clock = nvsmi(['clocks.current.memory'])[0] ref_mem_clock = 877 diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 30e507b0d..e5559ca7f 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -133,7 +133,7 @@ torch.manual_seed(0) x = torch.randn(1823, 781, device='cuda') y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) -print(torch.allclose(y_triton, y_torch)) +assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) # %% # As expected, the results are identical. diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 240583df2..ddfe9c0bc 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -237,9 +237,9 @@ def matmul_kernel( a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # you can fuse arbitrary activation functions here - # while the accumulator is still in FP32 ! - if meta['ACTIVATION']: - accumulator = meta['ACTIVATION'](accumulator) + # while the accumulator is still in FP32! 
+ if ACTIVATION: + accumulator = ACTIVATION(accumulator) c = accumulator.to(tl.float16) # ----------------------------------------------------------- diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index 5c4f53435..cf172537a 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -42,9 +42,8 @@ def _dropout( output_ptr, # pointer to the output n_elements, # number of elements in the `x` tensor p, # probability that an element of `x` is changed to zero - **meta, + BLOCK_SIZE: tl.constexpr, ): - BLOCK_SIZE = meta['BLOCK_SIZE'] pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -108,10 +107,9 @@ def _seeded_dropout( n_elements, p, seed, - **meta, + BLOCK_SIZE: tl.constexpr, ): # compute memory offsets of elements handled by this instance - BLOCK_SIZE = meta['BLOCK_SIZE'] pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 82231e15c..1cefc60b9 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -8,11 +8,19 @@ import torch import triton import triton.language as tl +try: + # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it + # should not be added to extras_require in setup.py. + import apex + HAS_APEX = True +except ModuleNotFoundError: + HAS_APEX = False + # Forward Pass @triton.jit -def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] +def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, + BLOCK_SIZE: tl.constexpr): # position of elements processed by this program row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE) @@ -42,11 +50,8 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): # Backward pass (DX + partial DW + partial DB) @triton.jit -def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, - stride, N, eps, - **META): - GROUP_SIZE_M = META['GROUP_SIZE_M'] - BLOCK_SIZE_N = META['BLOCK_SIZE_N'] +def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps, + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # position of elements processed by this program row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) @@ -102,15 +107,14 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, @triton.jit -def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): +def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): pid = tl.program_id(0) - BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] - BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): - rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) + rows = i + tl.arange(0, BLOCK_SIZE_M) mask = (rows[:, None] < M) & (cols[None, :] < N) offs = rows[:, None] * N + cols[None, :] dw += tl.load(DW + offs, mask=mask, other=0.) 
@@ -216,8 +220,8 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): x_names=['N'], x_vals=[512 * i for i in range(2, 32)], line_arg='provider', - line_vals=['triton', 'torch', 'apex'], - line_names=['Triton', 'Torch', 'Apex'], + line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []), + line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', plot_name='layer-norm-backward', @@ -239,7 +243,6 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c if provider == 'torch': y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) if provider == 'apex': - import apex apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype) y_fwd = lambda: apex_layer_norm(x) # forward pass diff --git a/python/tutorials/README.rst b/python/tutorials/README.rst index 24c752842..a36a08bbe 100644 --- a/python/tutorials/README.rst +++ b/python/tutorials/README.rst @@ -1,4 +1,11 @@ Tutorials ================== -Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one. \ No newline at end of file +Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one. + +To install the dependencies for the tutorials: + +.. code-block:: bash + + cd triton + pip install -e './python[tutorials]' From a70acfec771d5e1ad9b4df8baa50166e99954e32 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 7 Jan 2022 13:11:34 -0800 Subject: [PATCH 043/215] [STYLE] add isort and autopep8 config files and check on CI (#423) Also a fix a few more style issues from the "aggressive" mode of autopep8. 
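A practical note for anyone applying this series locally: the CI steps added below check formatting with `isort -c ./python` and `autopep8 -a -r -d --exit-code ./python`, and on failure they echo the matching fix-up commands, `isort ./python` and `autopep8 -a -r -i ./python`. The configuration for both tools lives in the new `.isort.cfg` and the `[pycodestyle]` section of `python/setup.cfg`.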
--- .github/workflows/integration-tests.yml | 6 + .isort.cfg | 4 + python/setup.cfg | 3 + python/setup.py | 4 +- python/test/regression/test_performance.py | 2 +- python/triton/__init__.py | 2 +- python/triton/code_gen.py | 2 +- python/triton/language/core.py | 2 +- python/triton/ops/blocksparse/matmul.py | 141 +++++++++++---------- python/triton/tools/disasm.py | 8 +- python/tutorials/02-fused-softmax.py | 5 +- 11 files changed, 102 insertions(+), 77 deletions(-) create mode 100644 .isort.cfg diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index d99e95dc7..c01a16de1 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -30,6 +30,12 @@ jobs: cd python pip3 install -e '.[tests]' + - name: Check imports + run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )" + + - name: Check style + run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )" + - name: Unit tests run: | cd python/test/unit diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..833801cca --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,4 @@ +[settings] +known_local_folder=triton +line_length=88 +py_version=36 diff --git a/python/setup.cfg b/python/setup.cfg index 08aedd7e6..9d24c7de7 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -1,2 +1,5 @@ [metadata] description_file = README.md + +[pycodestyle] +ignore = E501,E701,E731 diff --git a/python/setup.py b/python/setup.py index 1171ad0a8..1cc2ea103 100644 --- a/python/setup.py +++ b/python/setup.py @@ -94,7 +94,7 @@ class CMakeBuild(build_ext): "-DBUILD_PYTHON_MODULE=ON", "-DLLVM_INCLUDE_DIRS=" + llvm_include_dir, "-DLLVM_LIBRARY_DIR=" + llvm_library_dir, - #'-DPYTHON_EXECUTABLE=' + sys.executable, + # '-DPYTHON_EXECUTABLE=' + sys.executable, # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) @@ -148,6 +148,8 @@ setup( ], extras_require={ "tests": [ + "autopep8", + "isort", "numpy", "pytest", "scipy>=1.7.1", diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index f6e7ec237..84e829aa8 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -1,4 +1,3 @@ -import triton.language as tl import subprocess import sys @@ -7,6 +6,7 @@ import torch from numpy import record import triton +import triton.language as tl ####################### # Utilities diff --git a/python/triton/__init__.py b/python/triton/__init__.py index c079880e9..b4a92a8f8 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,4 +1,4 @@ -# version +"""isort:skip_file""" __version__ = '2.0.0' # TODO: torch needs to be imported first diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 439c1798e..af95bf280 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -852,7 +852,7 @@ class Autotuner: else: config = self.configs[0] self.best_config = config - if config.pre_hook != None: + if config.pre_hook is not None: config.pre_hook(self.nargs) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6895c101c..4f63b33bc 100644 --- a/python/triton/language/core.py +++ 
b/python/triton/language/core.py @@ -293,7 +293,7 @@ class block: dst_shape = [] curr = 0 for sl in slices: - if sl == None: + if sl is None: dst_shape.append(1) elif sl == slice(None, None, None): dst_shape.append(src_shape[curr]) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 15e6c0523..9a04ded66 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -26,9 +26,9 @@ def _sdd_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, BLOCK: tl.constexpr, EVEN_K: tl.constexpr ): - #------------# - #- Prologue -# - #------------# + # ------------ # + # - Prologue - # + # ------------ # block_id = tl.program_id(1) + grid_offset lut += block_id * 3 # offsets @@ -39,21 +39,23 @@ def _sdd_kernel( start_am = tl.load(lut + 1) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_ak = tl.arange(0, TILE_K) - a_ptrs = A + (off_z * stride_za - + off_h * stride_ha - + offs_am[:, None] * stride_ma - + offs_ak[None, :] * stride_ak) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak # initialize pointers to B start_bn = tl.load(lut + 2) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bk = tl.arange(0, TILE_K) - b_ptrs = B + (off_z * stride_zb - + off_h * stride_hb - + offs_bn[None, :] * stride_nb - + offs_bk[:, None] * stride_bk) - ## ---------------- ## - ## Inner Loop ## - ## ---------------- ## + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) for k in range(K, 0, -TILE_K): if EVEN_K: @@ -66,15 +68,16 @@ def _sdd_kernel( a_ptrs += TILE_K * stride_ak b_ptrs += TILE_K * stride_bk c = acc.to(C.dtype.element_ty) - ## ---------------- ## - ## Epilogue ## - ## ---------------- ## + # ---------------- # + # Epilogue # + # ---------------- # offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK - pc = C + (off_z * stride_zc - + block_id * stride_hc - + offs_cm[:, None] * stride_mc - + offs_cn[None, :] * stride_nc) + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc tl.store(pc, c, mask=True) @@ -134,9 +137,9 @@ def _dsd_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr ): - #------------# - #- Prologue -# - #------------# + # ------------ # + # - Prologue - # + # ------------ # pid_m = tl.program_id(0) pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) @@ -168,9 +171,9 @@ def _dsd_kernel( + off_h * stride_hb \ + offs_bn[None, :] * stride_bn \ + offs_bk[:, None] * stride_bk - ## ---------------- ## - ## Inner Loop ## - ## ---------------- ## + # ---------------- # + # Inner Loop # + # ---------------- # acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) pinc += 2 inc_a = tl.load(pinc + 1) @@ -192,7 +195,8 @@ def _dsd_kernel( # initialize pointers to C offs_cm = column * TILE_M + tl.arange(0, TILE_M) offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) - pc = C + off_h * stride_hc \ + pc = C \ + + off_h * stride_hc \ + pidz * stride_zc \ + offs_cm[:, None] * stride_cm \ + offs_cn[None, :] * stride_cn @@ -224,24 +228,24 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N TILE_N = 128 # compute output 
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0] - # fmt: off _dsd_kernel[grid]( a, b, c, a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), BS3, AS1, lut, - TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, num_warps=4, GROUP_SIZE_M=4, ) # exit() return c + def dsd_lut(layout, block, step, trans, device): sizes = torch.sum(layout, 2 if trans else 1) head_id, col_id = sizes.nonzero(as_tuple=True) sizes = sizes.flatten() - segments = sizes*step + segments = sizes * step # pointer increments if trans: nnz = layout.nonzero(as_tuple=False) @@ -301,13 +305,13 @@ def dsd_lut(layout, block, step, trans, device): A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] A_incs = A_incs.view(-1) # create header - width = col_id.size(0) - offsets = offsets*2*div + 4*width - segments = segments*div - header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() # create increments - incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() - incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) # create lut lut = torch.cat((header, incs)) lut = lut.type(torch.int32).to(device) @@ -317,6 +321,8 @@ def dsd_lut(layout, block, step, trans, device): # ----------------------------- # Dense = Dense x Sparse (DDS) # ----------------------------- + + @triton.jit def _dds_kernel( A, B, C, @@ -327,9 +333,9 @@ def _dds_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, ): - #------------# - #- Prologue -# - #------------# + # ------------ # + # - Prologue - # + # ------------ # pid_m = tl.program_id(0) pid_n = tl.program_id(1) num_pid_m = tl.num_programs(0) @@ -343,31 +349,31 @@ def _dds_kernel( off_h = tl.load(header + 3) pinc = lut + offset # initialize pointers to A (dense) - offs_am = pid_m*TILE_M + tl.arange(0, TILE_M) + offs_am = pid_m * TILE_M + tl.arange(0, TILE_M) offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M) start_ak = tl.load(pinc) start_ak = tl.multiple_of(start_ak, 8) offs_ak = start_ak + tl.arange(0, TILE_K) ptrs_a = A + pid_z * stride_za \ - + off_h * stride_ha \ - + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ka + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ka # initialize pointers to B (sparse) block_id = tl.load(pinc + 1) block_id = tl.multiple_of(block_id, 8) offs_bn = tl.arange(0, TILE_N) offs_bk = tl.arange(0, TILE_K) ptrs_b = B + pid_z * stride_zb \ - + block_id * stride_hb \ - + offs_bn[None, :] * stride_bn \ - + offs_bk[:, None] * stride_bk - ## ---------------- ## - ## Inner Loop ## - ## ---------------- ## + + block_id * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # acc = tl.zeros((TILE_M, TILE_N), 
dtype=tl.float32) for k in range(AS1, 0, -TILE_K): - a = tl.load(ptrs_a, mask = offs_am[:, None] < DS0) - b = tl.load(ptrs_b, mask = True) + a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0) + b = tl.load(ptrs_b, mask=True) acc += tl.dot(a, b) pinc += 2 inc_a = tl.load(pinc) @@ -377,21 +383,22 @@ def _dds_kernel( inc_a = inc_a * stride_ka ptrs_a += inc_a ptrs_b += inc_b - ## ---------------- ## - ## Epilogue ## - ## ---------------- ## + # ---------------- # + # Epilogue # + # ---------------- # c = acc.to(C.dtype.element_ty) # initialize pointers to C (dense) offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M) offs_cn = column * TILE_N + tl.arange(0, TILE_N) ptrs_c = C + off_h * stride_hc \ - + pid_z * stride_zc \ - + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc + + pid_z * stride_zc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc # write back - tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) + tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0) -def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None): + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): if a.stride(2) != 1 and a.stride(3) != 1: a = a.contiguous() if b.stride(2) != 1 and b.stride(3) != 1: @@ -414,14 +421,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = c = out TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0] - # fmt: off _dds_kernel[grid]( a, b, c, a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), AS2, BS2, lut, - TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4, + TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4, num_warps=4, GROUP_SIZE_M=4, ) return c @@ -429,6 +435,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = ############## # MAIN API # ############## + + class _matmul(torch.autograd.Function): fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} @@ -474,8 +482,9 @@ class _matmul(torch.autograd.Function): ) dout = dc if ctx.has_out else None return da, db, None, None, None,\ - None, None, None, None,\ - None, None, None, None, None, dout + None, None, None, None,\ + None, None, None, None, None, dout + class matmul: @@ -499,11 +508,11 @@ class matmul: self.da_lut, self.da_width = sdd_lut(layout, block, device) self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) if self.mode == 'dds': - self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) self.db_lut, self.db_width = sdd_lut(layout, block, device) - def __call__(self, a, b, out = None): + def __call__(self, a, b, out=None): c = _matmul.apply( a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, self.c_lut, self.c_width, diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 3b443c690..b030e72ec 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -52,7 +52,7 @@ def processSassLines(fline, sline, labels): asm = asm[:-2] + ";" ctrl = 
parseCtrl(sline) # BRA target address - if BRA_RE.match(asm) != None: + if BRA_RE.match(asm) is not None: target = int(BRA_RE.match(asm).group(2), 16) if target in labels: pass @@ -62,7 +62,7 @@ def processSassLines(fline, sline, labels): def extract(file_path, fun): - if fun == None: + if fun is None: sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) else: sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) @@ -77,7 +77,7 @@ def extract(file_path, fun): # /*0x...*/ fname_match = FNAME_RE.match(line) # Looking for new function header (function: ) - while FNAME_RE.match(line) == None: + while FNAME_RE.match(line) is None: line_idx += 1 if line_idx < len(sass_lines): line = sass_lines[line_idx].decode() @@ -94,7 +94,7 @@ def extract(file_path, fun): # store sass asm in buffer and them print them (for labels) # (ctrl, asm) asm_buffer = [] - while FLINE_RE.match(line) != None: + while FLINE_RE.match(line) is not None: # First line (Offset ASM Encoding) fline = sass_lines[line_idx].decode() line_idx += 1 diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index e5559ca7f..7af24e18d 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -16,10 +16,11 @@ You will learn about: # Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. # Let us consider instead the case of a simple (numerically stabilized) softmax operation: -import triton.language as tl -import triton import torch +import triton +import triton.language as tl + @torch.jit.script def naive_softmax(x): From efdabe6073f622ab8c065159ee01d908d3ac74b3 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 7 Jan 2022 15:28:36 -0800 Subject: [PATCH 044/215] [STYLE] check python with flake8 (#424) I've been using this locally to find errors without running tests, and now that we're using autopep8, it passes with minimal suppressions. This is also what turned up the issues with the tutorials, which were fixed in #422. --- .github/workflows/integration-tests.yml | 3 +++ python/bench/bench_matmul.py | 1 - python/setup.cfg | 3 +++ python/setup.py | 1 + python/test/regression/test_performance.py | 1 - python/test/unit/language/test_core.py | 1 + python/test/unit/language/test_random.py | 1 - python/triton/__init__.py | 1 + python/triton/code_gen.py | 6 ++---- python/triton/language/__init__.py | 1 + python/triton/language/core.py | 8 -------- python/triton/ops/__init__.py | 1 + python/triton/ops/blocksparse/__init__.py | 1 + python/triton/ops/cross_entropy.py | 6 +----- python/triton/ops/matmul.py | 2 +- python/triton/tools/disasm.py | 3 +-- python/tutorials/01-vector-add.py | 2 +- 17 files changed, 18 insertions(+), 24 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index c01a16de1..45798e628 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -36,6 +36,9 @@ jobs: - name: Check style run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )" + - name: Flake8 + run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' 
; exit 1 )" + - name: Unit tests run: | cd python/test/unit diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 9db005da0..b776b3dbf 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -50,7 +50,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): a = a.t() if BT: b = b.t() - num_flops = 2 * M * N * K tflops = lambda ms: 2. * M * N * K / ms * 1e-9 if provider == "cublas": ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) diff --git a/python/setup.cfg b/python/setup.cfg index 9d24c7de7..9af1cf69c 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -3,3 +3,6 @@ description_file = README.md [pycodestyle] ignore = E501,E701,E731 + +[flake8] +ignore = E501,E701,E731 diff --git a/python/setup.py b/python/setup.py index 1cc2ea103..db22c14af 100644 --- a/python/setup.py +++ b/python/setup.py @@ -149,6 +149,7 @@ setup( extras_require={ "tests": [ "autopep8", + "flake8", "isort", "numpy", "pytest", diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 84e829aa8..39299a89a 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -3,7 +3,6 @@ import sys import pytest import torch -from numpy import record import triton import triton.language as tl diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7f0af78b4..d8e88a609 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,3 +1,4 @@ +# flake8: noqa: F821,F841 import copy import itertools import re diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 82ae7f0c2..042065403 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -2,7 +2,6 @@ import numpy as np import pytest import scipy.stats import torch -from numpy.random import Philox import triton import triton.language as tl diff --git a/python/triton/__init__.py b/python/triton/__init__.py index b4a92a8f8..f9982939c 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,4 +1,5 @@ """isort:skip_file""" +# flake8: noqa: F401 __version__ = '2.0.0' # TODO: torch needs to be imported first diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index af95bf280..b7da2047e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,19 +1,17 @@ import ast import builtins -import dbm import functools import hashlib import inspect import os import pickle -import struct import subprocess import sys import tempfile import textwrap import time import warnings -from typing import Dict, Optional +from typing import Dict import torch from filelock import FileLock @@ -406,7 +404,7 @@ class CodeGenerator(ast.NodeVisitor): self.visit(pos_cond_node), self.visit(neg_cond_node), _builder=self.builder) - #cond_node = neg_cond_node + # cond_node = neg_cond_node step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index a7f341f16..0b04465eb 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,3 +1,4 @@ +# flake8: noqa: F401 from . 
import core, random from .core import * from .random import * diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 4f63b33bc..d32da45c3 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -802,14 +802,6 @@ def max_contiguous(input, value, _builder=None): return frontend.max_contiguous(input, value, _builder) -@builtin -def max_contiguous(input, value, _builder=None): - """ - Let the compiler knows that the `value` first values in :code:`input` are contiguous. - """ - return frontend.max_contiguous(input, value, _builder) - - # ----------------------- # Standard library # ----------------------- diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py index 7d27ffd20..dcaed8ccf 100644 --- a/python/triton/ops/__init__.py +++ b/python/triton/ops/__init__.py @@ -1,3 +1,4 @@ +# flake8: noqa: F401 #from .conv import _conv, conv from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy diff --git a/python/triton/ops/blocksparse/__init__.py b/python/triton/ops/blocksparse/__init__.py index 231c27a1f..df3353e12 100644 --- a/python/triton/ops/blocksparse/__init__.py +++ b/python/triton/ops/blocksparse/__init__.py @@ -1,2 +1,3 @@ +# flake8: noqa: F401 from .matmul import matmul from .softmax import softmax diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index dfd4f4487..910417d2c 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -1,5 +1,3 @@ -import os - import torch import triton @@ -96,11 +94,9 @@ class _cross_entropy(torch.autograd.Function): """ # load saved tensors neg_logprobs, indices = ctx.saved_tensors - # make kernel - device, dtype = neg_logprobs.device, neg_logprobs.dtype - n_cols = neg_logprobs.shape[-1] # run the kernel # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] grid = lambda opt: (neg_logprobs.numel() // n_cols, ) _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) return neg_logprobs, None diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 60ecc9f3b..d7af57406 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -2,7 +2,7 @@ import torch import triton import triton.language as tl -from .matmul_perf_model import * +from .matmul_perf_model import estimate_matmul_time, prune_num_stages def init_to_zero(name): diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index b030e72ec..3672d4b05 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import argparse import re import subprocess @@ -75,7 +74,7 @@ def extract(file_path, fun): # .headerflags: ... # /*0000*/ asmstr /*0x...*/ # /*0x...*/ - fname_match = FNAME_RE.match(line) + # Looking for new function header (function: ) while FNAME_RE.match(line) is None: line_idx += 1 diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index c78ccabbc..ca8b463fe 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -65,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor): # - each torch.tensor object is implicitly converted into a pointer to its first element. 
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel # - don't forget to pass meta-parameters as keywords arguments - pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output From 94a2e10fe57015375b6b264c67d0453fe7d10789 Mon Sep 17 00:00:00 2001 From: daadaada Date: Wed, 12 Jan 2022 02:20:31 +0800 Subject: [PATCH 045/215] [BACKEND] Add bf16 & tf32 mma supports (on A100) (#426) --- include/triton/codegen/analysis/layout.h | 77 ++- include/triton/ir/builder.h | 2 +- include/triton/ir/dispatch.h | 2 +- include/triton/ir/instructions.h | 25 +- lib/codegen/analysis/layout.cc | 103 +++- lib/codegen/analysis/swizzle.cc | 9 +- lib/codegen/pass.cc | 1 - lib/codegen/selection/common.h | 78 +++ lib/codegen/selection/generator.cc | 607 +++++++++++++++-------- lib/codegen/transform/peephole.cc | 2 +- lib/codegen/transform/prefetch.cc | 9 +- lib/driver/llvm.cc | 1 + lib/ir/builder.cc | 4 +- lib/ir/dispatch.cc | 5 +- lib/ir/instructions.cc | 23 +- python/test/unit/language/test_core.py | 24 +- python/triton/language/core.py | 8 +- 17 files changed, 717 insertions(+), 263 deletions(-) create mode 100644 lib/codegen/selection/common.h diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 4d12e34c0..dc5150f05 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -109,6 +109,63 @@ protected: }; class mma_layout: public distributed_layout { +public: + enum TensorCoreType : uint8_t { + // floating-point tensor core instr + FP32_FP16_FP16_FP32 = 0, // default + FP32_BF16_BF16_FP32, + FP32_TF32_TF32_FP32, + // integer tensor core instr + INT32_INT1_INT1_INT32, // Not implemented + INT32_INT4_INT4_INT32, // Not implemented + INT32_INT8_INT8_INT32, // Not implemented + // + NOT_APPLICABLE, + }; + + // Used on nvidia GPUs with sm >= 80 + inline static const std::map> mma_instr_shape_ = { + {FP32_FP16_FP16_FP32, {16, 8, 16}}, + {FP32_BF16_BF16_FP32, {16, 8, 16}}, + {FP32_TF32_TF32_FP32, {16, 8, 8}}, + + {INT32_INT1_INT1_INT32, {16, 8, 256}}, + {INT32_INT4_INT4_INT32, {16, 8, 64}}, + {INT32_INT8_INT8_INT32, {16, 8, 32}}, + }; + + // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices) + inline static const std::map> mma_mat_shape_ = { + {FP32_FP16_FP16_FP32, {8, 8, 8}}, + {FP32_BF16_BF16_FP32, {8, 8, 8}}, + {FP32_TF32_TF32_FP32, {8, 8, 4}}, + + {INT32_INT1_INT1_INT32, {8, 8, 64}}, + {INT32_INT4_INT4_INT32, {8, 8, 32}}, + {INT32_INT8_INT8_INT32, {8, 8, 16}}, + }; + + inline static const std::map mma_instr_ptx_ = { + {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, + {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"}, + {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"}, + + {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"}, + {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"}, + {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, + }; + + // vector length per ldmatrix (16*8/elelment_size_in_bits) + inline static const std::map mma_instr_vec_ = { + {FP32_FP16_FP16_FP32, 8}, + {FP32_BF16_BF16_FP32, 8}, + {FP32_TF32_TF32_FP32, 4}, + + {INT32_INT1_INT1_INT32, 128}, + 
{INT32_INT4_INT4_INT32, 32}, + {INT32_INT8_INT8_INT32, 16}, + }; + public: mma_layout(size_t num_warps, const std::vector& axes, @@ -116,7 +173,8 @@ public: const std::vector &values, analysis::align* align, target *tgt, shared_layout* layout_a, - shared_layout* layout_b); + shared_layout* layout_b, + ir::value *dot); void accept(layout_visitor* vst) { vst->visit_layout_mma(this); } // accessor int fpw(size_t k) { return fpw_.at(k); } @@ -124,6 +182,16 @@ public: int spw(size_t k) { return spw_.at(k); } int rep(size_t k) { return rep_.at(k); } + // helpers for generator.cc + std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); } + std::vector get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); } + std::vector get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); } + int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); } + int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); } + + // setter + void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; } + private: // fragment per warp std::vector fpw_; @@ -135,6 +203,8 @@ private: std::vector spt_; // repetitions std::vector rep_; + + TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32; }; struct scanline_layout: public distributed_layout { @@ -182,7 +252,7 @@ public: const std::vector& shapes, const std::vector &values_, ir::type *ty, - analysis::align* align); + analysis::align* align, target *tgt); void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } // accessors size_t get_size() { return size_; } @@ -197,6 +267,7 @@ public: ir::value* hmma_dot_b() { return hmma_dot_b_; } void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } int get_mma_vec() { return mma_vec_;} + int get_mma_strided() { return mma_strided_; } data_layout* get_arg_layout() { return arg_layout_; } private: @@ -209,6 +280,8 @@ private: ir::value* hmma_dot_b_; data_layout* arg_layout_; int mma_vec_; + int mma_strided_; + target *tgt_; }; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 3a4094123..67ab47c90 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -154,7 +154,7 @@ public: value *create_cos(value* arg); value *create_sin(value* arg); value *create_log(value* arg); - value *create_dot(value *A, value *B, value *C); + value *create_dot(value *A, value *B, value *C, bool allow_tf32); value *create_trans(value *A, const std::vector &perm = {}); value *create_sqrt(value *A); value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index c90480f1e..d8293d231 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -80,7 +80,7 @@ struct dispatch{ static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); // linear algebra - static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder); + static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder); // indexing static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 7c147f634..ca1416f48 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -742,26 +742,29 @@ public: }; private: - dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, 
instruction *next); + dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next); std::string repr_impl() const { return "dot"; } - - bool is_prefetched_ = false; - DataType C_type_ = DataType::FP32; - DataType A_type_ = DataType::FP16; - DataType B_type_ = DataType::FP16; public: bool is_prefetched() const { return is_prefetched_; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } + bool allow_tf32() const { return allow_tf32_; } public: - static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); - static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); + static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(dot_inst) _TRITON_DEFINE_ACCEPT(dot_inst) + +private: + bool is_prefetched_ = false; + bool allow_tf32_ = false; + DataType C_type_ = DataType::FP32; + DataType A_type_ = DataType::FP16; + DataType B_type_ = DataType::FP16; }; //class outer_inst: public builtin_inst { diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 64163c91c..d00959e45 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -23,19 +23,65 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) { return std::min(std::max(x, lo), hi); } -inline bool is_hmma_c(ir::value *v){ +inline bool is_hmma_c(ir::value *v, int sm){ bool result = false; if(auto *x = dynamic_cast(v)){ ir::value *a = x->get_operand(0); ir::type *a_ty = a->get_type(); ir::value *b = x->get_operand(1); ir::type *b_ty = b->get_type(); - result = a_ty->get_scalar_ty()->is_fp16_ty() && - b_ty->get_scalar_ty()->is_fp16_ty(); + result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) || + (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) || + (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() && + x->allow_tf32() && sm >= 80); } return result; } +static mma_layout::TensorCoreType get_mma_type(ir::value *v) { + mma_layout::TensorCoreType mma_type; + if (auto* dot = dynamic_cast(v)) { + ir::value* a = dot->get_operand(0); + ir::value* b = dot->get_operand(1); + ir::type* a_ty = a->get_type(); + ir::type* b_ty = b->get_type(); + ir::type* c_ty = v->get_type(); + + if (c_ty->get_scalar_ty()->is_fp32_ty()) { + // floating point tensor cores + if 
(a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) { + mma_type = mma_layout::FP32_FP16_FP16_FP32; + return mma_type; + } + if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) { + mma_type = mma_layout::FP32_BF16_BF16_FP32; + return mma_type; + } + if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() + && dot->allow_tf32()) { + mma_type = mma_layout::FP32_TF32_TF32_FP32; + return mma_type; + } + } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) { + throw std::runtime_error("integer tensor cores are not yet supported"); + // // integer tensor cores + // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) { + // mma_type = mma_layout::INT32_INT1_INT1_INT32; + // return mma_type; + // } + // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) { + // mma_type = mma_layout::INT32_INT4_INT4_INT32; + // return mma_type; + // } + // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) { + // mma_type = mma_layout::INT32_INT8_INT8_INT32; + // return mma_type; + // } + } + } + return mma_layout::NOT_APPLICABLE; +} + inline void extract_io_use(ir::value *v, std::set& result) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); @@ -52,11 +98,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) { } } -inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) { +inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); - if(i && is_hmma_c(i) && i->get_operand(n) == v) + if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) { result = i; + } } } @@ -142,7 +189,9 @@ mma_layout::mma_layout(size_t num_warps, const std::vector& shape, const std::vector &values, analysis::align* align, target* tgt, - shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) { + shared_layout *layout_a, shared_layout *layout_b, + ir::value *dot): distributed_layout(MMA, axes, shape, values, align) { + tensor_core_type_ = get_mma_type(dot); /* fragments per warp */ // try to make things as square as possible to maximize data re-use if(tgt->as_nvidia()->sm() < 80){ @@ -159,9 +208,9 @@ mma_layout::mma_layout(size_t num_warps, spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; } else{ - fpw_ = {1, 1, 1}; - spw_ = {16, 8, 1}; - rep_ = {2, 2, 1}; + // fpw_ = {1, 1, 1}; + spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 + // rep_ = {2, 2, 1}; } order_ = {0, 1}; @@ -356,7 +405,8 @@ shared_layout::shared_layout(data_layout *arg, const std::vector& shape, const std::vector &values, ir::type *ty, - analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) { + analysis::align* align, target *tgt) + : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) { size_ = 0; arg_layout_ = arg; @@ -382,12 +432,25 @@ shared_layout::shared_layout(data_layout *arg, for(ir::value* v: values){ extract_dot_use(v, dot_a, 0); extract_dot_use(v, dot_b, 1); - extract_hmma_dot_use(v, hmma_dot_a, 0); - extract_hmma_dot_use(v, hmma_dot_b, 1); + extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm()); + extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm()); } hmma_dot_a_ = hmma_dot_a; hmma_dot_b_ = hmma_dot_b; + // Update mma_vec + if (hmma_dot_a_) { + assert(order_.size() == 2); + 
std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_)); + mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m + mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2]; + } else if (hmma_dot_b_) { + assert(order_.size() == 2); + std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_)); + mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k + mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1]; + } + // size size_ = ty_->get_primitive_size_in_bits() / 8; for(auto s: shape_) @@ -451,7 +514,8 @@ void layouts::make_graph(ir::instruction *i) { void layouts::create(size_t id, const std::vector& values) { // if(layouts_.find(id) != layouts_.end()) // return; - auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); + auto it_hmma_c = std::find_if(values.begin(), values.end(), + [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); }); auto cmp = [](ir::value* x, ir::value *y) { std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; std::pair yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; @@ -473,13 +537,16 @@ void layouts::create(size_t id, const std::vector& values) { ir::value *b = dot->get_operand(1); create(groups_.at(a), values_.at(groups_.at(a))); create(groups_.at(b), values_.at(groups_.at(b))); - layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b))); + layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, + (shared_layout*)layouts_.at(groups_.at(a)), + (shared_layout*)layouts_.at(groups_.at(b)), + dot); } else if(it_cts != values.end()){ ir::instruction *cts = (ir::instruction*)*it_cts; ir::value *arg = cts->get_operand(0); create(groups_.at(arg), values_.at(groups_.at(arg))); - layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_); } else{ layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); @@ -516,7 +583,7 @@ void layouts::run(ir::module &mod) { scanline_layout *layout = get(arg)->to_scanline(); shapes[axis] = layout->mts(axis); // create layout - layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); tmp_[red] = id; } if(auto *val = dynamic_cast(i)){ @@ -529,12 +596,12 @@ void layouts::run(ir::module &mod) { shape[k] = std::max(in_layout->shape_per_cta(k), out_layout->shape_per_cta(k)); } - layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_); tmp_[val] = id; } if(auto *atom = dynamic_cast(i)){ id++; - layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_); tmp_[atom] = id; } }); diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index fcde938d9..1dbae10d4 100644 --- a/lib/codegen/analysis/swizzle.cc +++ 
b/lib/codegen/analysis/swizzle.cc @@ -19,6 +19,7 @@ void swizzle::run(ir::module &) { continue; ir::value* mma_dot_a = layout->hmma_dot_a(); ir::value* mma_dot_b = layout->hmma_dot_b(); + if(!mma_dot_a && !mma_dot_b){ per_phase_[layout] = 1; max_phase_[layout] = 1; @@ -39,10 +40,10 @@ void swizzle::run(ir::module &) { else vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); } - else{ - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[layout] = 8 / per_phase_[layout]; - vec_[layout] = 8; + else { + per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; + vec_[layout] = layout->get_mma_vec(); } } } diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index d38d81a9c..8921d6c84 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -85,7 +85,6 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); - // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/selection/common.h b/lib/codegen/selection/common.h new file mode 100644 index 000000000..c4b0951da --- /dev/null +++ b/lib/codegen/selection/common.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include "triton/codegen/selection/generator.h" +#include "triton/codegen/target.h" +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/allocation.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/analysis/swizzle.h" +#include "triton/codegen/transform/coalesce.h" +#include "triton/ir/context.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/type.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +namespace triton::codegen { +// types +#define void_ty builder_->getVoidTy() +#define f16_ty builder_->getHalfTy() +#define bf16_ty builder_->getBFloatTy() +#define f32_ty builder_->getFloatTy() +#define i8_ty builder_->getInt8Ty() +#define i32_ty builder_->getInt32Ty() +#define vec_ty(type, num_el) VectorType::get(type, num_el, false) +#define ptr_ty(...) PointerType::get(__VA_ARGS__) +// constants +#define i32(...) builder_->getInt32(__VA_ARGS__) +// ops +#define and_(...) builder_->CreateAnd(__VA_ARGS__) +#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__) +#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__) +#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__) +#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__) +#define br(...) builder_->CreateBr(__VA_ARGS__) +#define call(...) builder_->CreateCall(__VA_ARGS__) +#define cast(...) builder_->CreateCast(__VA_ARGS__) +#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__) +#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__) +#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__) +#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__) +#define fadd(...) builder_->CreateFAdd(__VA_ARGS__) +#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) +#define fmul(...) builder_->CreateFMul(__VA_ARGS__) +#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) +#define fsub(...) 
builder_->CreateFSub(__VA_ARGS__) +#define icmp(...) builder_->CreateICmp(__VA_ARGS__) +#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) +#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) +#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) +#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) +#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) +#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) +#define load(...) builder_->CreateLoad(__VA_ARGS__) +#define lshr(...) builder_->CreateLShr(__VA_ARGS__) +#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) +#define min_num(...) builder_->CreateMinNum(__VA_ARGS__) +#define neg(...) builder_->CreateNeg(__VA_ARGS__) +#define phi(...) builder_->CreatePHI(__VA_ARGS__) +#define ret(...) builder_->CreateRet(__VA_ARGS__) +#define select(...) builder_->CreateSelect(__VA_ARGS__) +#define store(...) builder_->CreateStore(__VA_ARGS__) +#define sub(...) builder_->CreateSub(__VA_ARGS__) +#define shl(...) builder_->CreateShl(__VA_ARGS__) +#define udiv(...) builder_->CreateUDiv(__VA_ARGS__) +#define urem(...) builder_->CreateURem(__VA_ARGS__) +#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) +#define xor_(...) builder_->CreateXor(__VA_ARGS__) + +} // namespace triton::codegen \ No newline at end of file diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a6148b2d1..b180ecb12 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ // return (*builder_)->CreateGEP(ty, ptr, vals, name); //} - // types #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() +#define bf16_ty builder_->getBFloatTy() #define f32_ty builder_->getFloatTy() #define i8_ty builder_->getInt8Ty() +#define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) @@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) #define xor_(...) 
builder_->CreateXor(__VA_ARGS__) - /** * \brief Convert Triton-IR Type to LLVM-IR Type */ @@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) { case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); - case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); + case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_); case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); @@ -457,19 +457,25 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 } Value* generator::bf16_to_fp32(Value *in0){ - Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2)); - ret = insert_elt(ret, in0, (uint64_t)1); - ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0); - return bit_cast(ret, builder_->getFloatTy()); + if (tgt_->as_nvidia()->sm() >= 80) { + InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), + "cvt.rn.f32.bf16 $0, $1;", "=r,h", false); + return call(ptx, {in0}); + } else { + Value *ret = UndefValue::get(vec_ty(i16_ty, 2)); + ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1); + ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0); + return bit_cast(ret, f32_ty); + } } Value* generator::fp32_to_bf16(Value *in0){ if(tgt_->as_nvidia()->sm() >= 80){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false), + InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false), "cvt.rn.bf16.f32 $0, $1;", "=h,r", false); return call(ptx, {in0}); } - return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1); + return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1); } /** @@ -514,12 +520,16 @@ void generator::visit_cast_inst(ir::cast_inst* x) { if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ // FP32 -> BF16 if(op_sca_ty->is_fp32_ty()) - for(size_t i = 0; i < x_idxs.size(); i++) - vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); + // for(size_t i = 0; i < x_idxs.size(); i++) + // vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); + for (indices_t idx: idxs_.at(x)) { + Value *arg = vals_[x->get_operand(0)][idx]; + vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); + } // BF16 -> FP32 if(ret_sca_ty->is_fp32_ty()) - for(size_t i = 0; i < x_idxs.size(); i++) - vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); + for(size_t i = 0; i < x_idxs.size(); i++) + vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); return; } @@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){ // --- std::vector ret_tys(n_words, IntegerType::get(*ctx_, width)); Type* ret_ty = ret_tys.size() > 1 ? 
StructType::get(*ctx_, ret_tys) : ret_tys[0]; + // ret_ty->print(llvm::outs()); std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); @@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){ } auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); + if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store + ty = f16_ty; for(size_t i = 0; i < idxs.size(); i += vec){ auto idx = idxs[i]; // pointer Value *ptr = vals_[ptr_op][idx]; - ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1)); + // vectorize + Type *v_ty = vec_ty(ty, vec); + ptr = bit_cast(ptr, v_ty->getPointerTo(1)); // value - Value* val = UndefValue::get(vec_ty(ty, vec)); + Value* val = UndefValue::get(v_ty); for(size_t ii = 0; ii < vec; ii++) - val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii); + val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii); if(mx){ Value *msk = vals_[mx->get_mask_operand()][idx]; Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); @@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va vals_[C][idxs_[C][i]] = acc[i]; } +namespace { +class mma16816_smem_loader { +public: + mma16816_smem_loader(int wpt, std::vector order, int k_order, + std::vector tile_shape, + std::vector instr_shape, std::vector mat_shape, + int per_phase, int max_phase, int dtsize, Builder *builder, + adder add, multiplier mul, geper gep) + : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape), + instr_shape_(instr_shape), mat_shape_(mat_shape), + per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder), + add(add), mul(mul), gep(gep) { + // compute compile-time constant variables & types + c_mat_shape_ = mat_shape[order[0]]; + s_mat_shape_ = mat_shape[order[1]]; + + c_stride_ = tile_shape[order[1]]; + s_stride_ = tile_shape[order[0]]; + + // rule: k must be the fast-changing axis + need_trans_ = k_order_ != order_[0]; + can_use_ldmatrix_ = dtsize == 2 || (!need_trans_); + + // std::cout << can_use_ldmatrix_ << std::endl; + // std::cout << need_trans_ << std::endl; + + // we need more pointers at the fast-changing axis, + if (can_use_ldmatrix_) + num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]]; + else // warning: this only works for tf32 & need transpose + num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]]; + num_ptr_ = std::max(num_ptr_, 2); + + + // load_v4 stride (in num of mats) + int load_stride_in_mat[2]; + load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2 + load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]); + p_load_stride_in_mat_ = load_stride_in_mat[order[0]]; + // stride in mat, used by load_v4 + s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]); + } + + std::vector compute_offs(Value *warp_off, Value *lane) { + // TODO: this needs to be moved to constructor (and extracted to arr_order) + mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_; + warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1]; + // start matrix logic offset (rename it as base_mat_off?) 
+ Value *mat_off[2] = {nullptr, nullptr}; + + if (can_use_ldmatrix_) { + // c: lane idx inside a group (a group is a collection of 8 contiguous threads) + // s: group idx (0,1,2,3) inside a warp + Value *c = urem(lane, i32(8)); + Value *s = udiv(lane, i32(8)); + // We can decompose s => s_0, s_1... + Value *s0 = urem(s, i32(2)); + Value *s1 = udiv(s, i32(2)); + + // We use different orders for a & b for better performance. + Value *k_mat_arr = (k_order_ == 1) ? s1 : s0; + Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1; + mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)), + mul(nk_mat_arr, i32(mat_arr_stride_))); + mat_off[k_order_] = k_mat_arr; + // physical offset (before swizzling) + Value *c_mat_off = mat_off[order_[0]]; + Value *s_mat_off = mat_off[order_[1]]; + // offset inside a matrix + Value *s_off_in_mat = c; + + std::vector offs(num_ptr_); + Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); + // pre-compute strided offset + Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); + for (int i=0; i < num_ptr_; ++i) { + Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_)); + c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle + offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_))); + } + return offs; + } else if (dtsize_ == 4 && need_trans_) { + // load tf32 matrices with lds32 + Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]] + Value *s_off_in_mat = urem(lane, i32(4)); // + + Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); + std::vector offs(num_ptr_); + for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time + int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; + int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; + if (k_mat_arr_int > 0) // we don't need pointers for k + continue; + Value *k_mat_arr = i32(k_mat_arr_int); + Value *nk_mat_arr = i32(nk_mat_arr_int); + // physical offset (before swizzling) + Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), + mul(nk_mat_arr, i32(mat_arr_stride_))); + Value *s_mat_off = k_mat_arr; // always 0? + Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); + // FIXME: (k_order_ == 1?) 
is really dirty hack + for (int i = 0; i < num_ptr_/2; ++i) { + Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2))); + c_mat_off_i = xor_(c_mat_off_i, phase); + Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); + // TODO: move this out of the loop + c_off = urem(c_off, i32(tile_shape_[order_[0]])); + s_off = urem(s_off, i32(tile_shape_[order_[1]])); + offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_))); + } + } + return offs; + // throw std::runtime_error("not implemented"); + } else + throw std::runtime_error("invalid smem load config"); + } + + std::tuple + load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn, + Value *pre_ptr, Value *next_ptr, std::vector &off, std::vector &ptrs, + FunctionType *ldmatrix_ty, Type *smem_ptr_ty, + std::map> &prefetch_latch_to_bb_) { + assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); + int mat_idx[2] = {mat0, mat1}; + int k = mat_idx[k_order_]; + + int ptr_idx = -1; + if (can_use_ldmatrix_) + ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]); + else // tf32 & trans + ptr_idx = mat_idx[order_[0]]; + + auto get_ptr = [&](int idx) -> Value* { + Value *ptr = nullptr; + if (k == 0 && is_prefetch) { + if (inc == 0) + ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty); + else + ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty); + } else + ptr = ptrs.at(idx); + return ptr; + }; + Value *ptr = get_ptr(ptr_idx); + + Value *res_v4 = nullptr; + if (can_use_ldmatrix_) { + std::string trans = need_trans_ ? ".trans" : ""; + // the offset (in byte) on the strided axis is a constant + int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_; + InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, + "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 " + "{$0, $1, $2, $3}, " + "[$4 + " + std::to_string(s_offset) + "];", + "=r,=r,=r,=r,r", true); + assert(ptr); + res_v4 = call(ldmatrix_ty, ld_fn, {ptr}); + if (k == 0 && inc == 1 && is_prefetch) + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4); + return {extract_val(res_v4, std::vector{0}), + extract_val(res_v4, std::vector{1}), + extract_val(res_v4, std::vector{2}), + extract_val(res_v4, std::vector{3})}; + } else { + // assert(false && "should not be here"); + assert(dtsize_ == 4 && need_trans_); + Value *ptr2 = get_ptr(ptr_idx+1); + assert(s_mat_stride_ == 1); + int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; + int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; + Value *elem0, *elem1, *elem2, *elem3; + if (k_order_ == 1) { + elem0 = load(gep(ptr, i32(s_offset_elem))); + elem1 = load(gep(ptr2, i32(s_offset_elem))); + elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); + elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); + } else { // for b (k first) + elem0 = load(gep(ptr, i32(s_offset_elem))); + elem2 = load(gep(ptr2, i32(s_offset_elem))); + elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); + elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); + } + if (k == 0 && inc == 1 && is_prefetch) { + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0); + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1); + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2); + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3); + } + return {elem0, elem1, elem2, elem3}; + } + } + + 
int get_num_ptr() const { return num_ptr_; } + +private: + int wpt_; + std::vector order_; + int k_order_; + std::vector tile_shape_; + std::vector instr_shape_; + std::vector mat_shape_; + int per_phase_, max_phase_; + int dtsize_; + + // generated + int c_mat_shape_, s_mat_shape_; + int c_stride_, s_stride_; + // p_: on the pointer axis + int p_load_stride_in_mat_; + int s_mat_stride_; + // stride when moving to next not-k mat + int warp_off_stride_; + int mat_arr_stride_; // matrix arrangement (inside a load) stride + bool need_trans_, can_use_ldmatrix_; + int num_ptr_; + + Builder *builder_; + adder add; + multiplier mul; + geper gep; +}; +} + /** * \brief Code Generation for `mma.16816` (A100) */ @@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); bool is_a_row = ord_a[0] == 1; bool is_b_row = ord_b[0] == 1; - std::string a_trans = is_a_row ? "" : ".trans"; - std::string b_trans = is_b_row ? ".trans" : ""; - int stride_a_m = is_a_row ? shape_a[1] : 1; - int stride_a_k = is_a_row ? 1 : shape_a[0]; - int stride_b_n = is_b_row ? 1 : shape_b[0]; - int stride_b_k = is_b_row ? shape_b[1] : 1; - int stride_a0 = is_a_row ? stride_a_k : stride_a_m; - int stride_a1 = is_a_row ? stride_a_m : stride_a_k; - int stride_b0 = is_b_row ? stride_b_n : stride_b_k; - int stride_b1 = is_b_row ? stride_b_k : stride_b_n; - int lda = is_a_row ? stride_a_m : stride_a_k; - int ldb = is_b_row ? stride_b_k : stride_b_n; - int per_phase_a = swizzle_->get_per_phase(layout_a); - int max_phase_a = swizzle_->get_max_phase(layout_a); - int per_phase_b = swizzle_->get_per_phase(layout_b); - int max_phase_b = swizzle_->get_max_phase(layout_b); - int num_ptr_a = 8; - int num_ptr_b = 8; - int vec_a = 8; - int vec_b = 8; + + std::vector mma_instr_shape = layout->get_mma_instr_shape(); + const int mma_instr_m = mma_instr_shape[0]; + const int mma_instr_n = mma_instr_shape[1]; + const int mma_instr_k = mma_instr_shape[2]; + + std::vector mat_shape = layout->get_mma_mat_shape(); + const int mat_shape_m = mat_shape[0]; + const int mat_shape_n = mat_shape[1]; + const int mat_shape_k = mat_shape[2]; + + const int per_phase_a = swizzle_->get_per_phase(layout_a); + const int max_phase_a = swizzle_->get_max_phase(layout_a); + const int per_phase_b = swizzle_->get_per_phase(layout_b); + const int max_phase_b = swizzle_->get_max_phase(layout_b); + + const int num_rep_m = shapes[0] / layout->shape_per_cta(0); + const int num_rep_n = shapes[1] / layout->shape_per_cta(1); + const int num_rep_k = std::max(NK/mma_instr_k, 1); Type *fp32_ty = f32_ty; Type *fp16x2_ty = vec_ty(f16_ty, 2); + Type *bf16x2_ty = vec_ty(bf16_ty, 2); Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); + Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty}); Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); - FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{ptr_ty(f16_ty, 3)}, false); + + FunctionType *ldmatrix_ty = nullptr; + FunctionType *mma_ty = nullptr; + Type *phi_ty = nullptr; + Type *smem_ptr_ty = nullptr; + + ir::type *A_ir_ty = A->get_type()->get_scalar_ty(); + ir::type *B_ir_ty = B->get_type()->get_scalar_ty(); + if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) { + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, 
fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(f16_ty, 3); + ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = fp16x2_ty; + } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { + // FIXME: We should use bf16 here. + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(f16_ty, 3); + ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = fp16x2_ty; + // mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + // smem_ptr_ty = ptr_ty(bf16_ty, 3); + // ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + // phi_ty = bf16x2_ty; + } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(fp32_ty, 3); + ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = fp32_ty; + } else + throw std::runtime_error("mma16816 data type not supported"); // left-hand-side values - std::map, std::pair> ha; + std::map, Value*> ha; std::map, Value*> hb; BasicBlock* CurrBB = builder_->GetInsertBlock(); @@ -1377,79 +1645,66 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value *lane = urem(thread, i32(32)); Value *warp = udiv(thread, i32(32)); - Value *warp12 = udiv(warp, i32(layout->wpt(0))); - Value *warp0 = urem(warp, i32(layout->wpt(0))); - Value *warp1 = urem(warp12, i32(layout->wpt(1))); + Value *warp_mn = udiv(warp, i32(layout->wpt(0))); + Value *warp_m = urem(warp, i32(layout->wpt(0))); + Value *warp_n = urem(warp_mn, i32(layout->wpt(1))); std::vector& fc = fcs.begin()->second; - Value *tidr8 = urem(lane, i32(8)); - Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a)); - Value* off_a0 = mul(tidr8, i32(lda)); - Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8)); - Value *off_ak = mul(udiv(lane, i32(16)), i32(8)); - off_am = urem(off_am, i32(shape_a[0])); - off_ak = urem(off_ak, i32(shape_a[1])); - off_a0 = add(off_a0, is_a_row ? off_ak : off_am); - Value* off_a1 = is_a_row ? off_am : off_ak; - std::vector off_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++){ - Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0)))); - off_a0i = exact_udiv(off_a0i, i32(vec_a)); - off_a0i = xor_(off_a0i, phase_a); - off_a0i = mul(off_a0i, i32(vec_a)); - off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); - } + size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b)); - Value* off_b0 = mul(tidr8, i32(ldb)); - Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8)); - Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8)); - off_bn = urem(off_bn, i32(shape_b[1])); - off_bk = urem(off_bk, i32(shape_b[0])); - off_b0 = add(off_b0, is_b_row ? 
off_bn : off_bk); - Value* off_b1 = is_b_row ? off_bk : off_bn; - std::vector off_b(num_ptr_b); - for(int i = 0; i < num_ptr_b; i++){ - Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16))); - off_b0i = exact_udiv(off_b0i, i32(vec_b)); - off_b0i = xor_(off_b0i, phase_b); - off_b0i = mul(off_b0i, i32(vec_b)); - off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); - } + // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride + // v (s0_0(0), s1_0(2), | *num_rep_k + // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) + // ----------- + // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0)) + mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, + {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, + per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); + std::vector off_a = a_loader.compute_offs(warp_m, lane); + int num_ptr_a = a_loader.get_num_ptr(); + + // | -> n (col-major) + // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n + // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) + // ----------- + // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) + mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b, + {mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n}, + per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); + std::vector off_b = b_loader.compute_offs(warp_n, lane); + int num_ptr_b = b_loader.get_num_ptr(); builder_->SetInsertPoint(CurrBB); // A pointer std::vector ptrs_a(num_ptr_a); for(int i = 0; i < num_ptr_a; i++) - ptrs_a[i] = gep(shmems_[A], {off_a[i]}); + ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); // B pointer std::vector ptrs_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++) - ptrs_b[i] = gep(shmems_[B], {off_b[i]}); + ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); - FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{$0, $1, $2, $3}, " - "{$4, $5, $6, $7}, " - "{$8, $9}, " - "{$10, $11, $12, $13};", + InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + + " {$0, $1, $2, $3}," + " {$4, $5, $6, $7}," + " {$8, $9}," + " {$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); - unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0); - unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1); - - // create mma & unpack result - auto call_mma = [&](unsigned m, unsigned n, unsigned K) { - unsigned cols_per_thread = num_rep_0 * 2; + // create mma & unpack result, m, n, k are offsets in mat + auto call_mma = [&](unsigned m, unsigned n, unsigned k) { + unsigned cols_per_thread = num_rep_m * 2; std::vector idx = { - (m*2 + 0) + (n*2 + 0)*cols_per_thread, - (m*2 + 0) + (n*2 + 1)*cols_per_thread, - (m*2 + 1) + (n*2 + 0)*cols_per_thread, - (m*2 + 1) + (n*2 + 1)*cols_per_thread + (m + 0) + (n*2 + 0)*cols_per_thread, + (m + 0) + (n*2 + 1)*cols_per_thread, + (m + 1) + (n*2 + 0)*cols_per_thread, + (m + 1) + (n*2 + 1)*cols_per_thread }; - Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second, - hb[{n, K}], hb[{n, K+8}], - fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); + Value *nc = call(mma_ty, mma_fn, + {ha[{m, k}], ha[{m+1, k}], ha[{m, 
k+1}], ha[{m+1, k+1}], + hb[{n, k}], hb[{n, k+1}], + fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); fc[idx[0]] = extract_val(nc, std::vector{0}); fc[idx[1]] = extract_val(nc, std::vector{1}); fc[idx[2]] = extract_val(nc, std::vector{2}); @@ -1459,131 +1714,83 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: ir::phi_node* phiA = dynamic_cast(A); ir::phi_node* phiB = dynamic_cast(B); - auto register_lds = - [&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) { - if (K <= 8 && is_prefetch) { - ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); - } else - vals[{m, K}] = {val0, val1}; - }; - auto register_lds2 = - [&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) { - if (K <= 8 && is_prefetch) { + [&](std::map, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) { + if (k < 2 && is_prefetch) { ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block)); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block)); } else - vals[{m, K}] = val; + vals[{n, k}] = val; }; - auto load_a = [&](int m, int K, int inc, bool is_prefetch) { - int offidx = (is_a_row ? K/16 : m) % num_ptr_a; - Value* ptra; - if(K == 0 && is_prefetch){ - if(inc == 0) - ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); - else - ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); - } - else - ptra = ptrs_a[offidx]; - int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); - int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K; - InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 " - "{$0, $1, $2, $3}, [$4 + " + - std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", - "=r,=r,=r,=r,r", true); - Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra}); - if(K == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa); - Value *ha0 = extract_val(haa, std::vector{0}); - Value *ha1 = extract_val(haa, std::vector{1}); - Value *ha2 = extract_val(haa, std::vector{2}); - Value *ha3 = extract_val(haa, std::vector{3}); - register_lds(ha, m, K, inc, ha0, ha1, is_prefetch); - register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch); + auto load_a = [&](int m, int k, int inc, bool is_prefetch) { + auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], + shared_next_ptr_[layout_a], off_a, ptrs_a, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(ha, m, k, inc, ha0, is_prefetch); + register_lds2(ha, m+1, k, inc, ha1, is_prefetch); + register_lds2(ha, m, k+1, inc, ha2, is_prefetch); + register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); }; - auto load_b = [&](int n, int K, int inc, bool is_prefetch) { - int offidx = (is_b_row ? n : K/16) % num_ptr_b; - Value* ptrb; - if(K == 0 && is_prefetch){ - if(inc == 0) - ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); - else - ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); - } - else - ptrb = ptrs_b[offidx]; - int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; - int step_bk = is_b_row ? 
K : K / (num_ptr_b*8)*(num_ptr_b*8); - InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 " - "{$0, $1, $2, $3}, [$4 + " + - std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", - "=r,=r,=r,=r,r", true); - Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb}); - if(K == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb); - Value *hb0 = extract_val(hbb, std::vector{0}); - Value *hb1 = extract_val(hbb, std::vector{1}); - Value *hb2 = extract_val(hbb, std::vector{2}); - Value *hb3 = extract_val(hbb, std::vector{3}); - register_lds2(hb, n, K, inc, hb0, is_prefetch); - register_lds2(hb, n+1, K, inc, hb2, is_prefetch); - register_lds2(hb, n, K+8, inc, hb1, is_prefetch); - register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch); + auto load_b = [&](int n, int k, int inc, bool is_prefetch) { + auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], + shared_next_ptr_[layout_b], off_b, ptrs_b, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(hb, n, k, inc, hb0, is_prefetch); + register_lds2(hb, n+1, k, inc, hb2, is_prefetch); + register_lds2(hb, n, k+1, inc, hb1, is_prefetch); + register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); }; if (C->is_prefetched()) { // create phis builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); - for(unsigned m = 0; m < num_rep_0; m++){ - ha[{m, 0}].first = phi(fp16x2_ty, 2); - ha[{m, 0}].second = phi(fp16x2_ty, 2); - ha[{m, 8}].first = phi(fp16x2_ty, 2); - ha[{m, 8}].second = phi(fp16x2_ty, 2); + for(unsigned m = 0; m < num_rep_m; m++){ + ha[{2*m, 0}] = phi(phi_ty, 2); + ha[{2*m+1, 0}] = phi(phi_ty, 2); + ha[{2*m, 1}] = phi(phi_ty, 2); + ha[{2*m+1, 1}] = phi(phi_ty, 2); } - for(unsigned n = 0; n < num_rep_1; n+=2){ - hb[{n, 0}] = phi(fp16x2_ty, 2); - hb[{n+1, 0}] = phi(fp16x2_ty, 2); - hb[{n, 8}] = phi(fp16x2_ty, 2); - hb[{n+1, 8}] = phi(fp16x2_ty, 2); + for(unsigned n = 0; n < num_rep_n; n+=2){ + hb[{n, 0}] = phi(phi_ty, 2); + hb[{n+1, 0}] = phi(phi_ty, 2); + hb[{n, 1}] = phi(phi_ty, 2); + hb[{n+1, 1}] = phi(phi_ty, 2); } // insert prefetched lds at the end of loop header builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); - for(unsigned m = 0; m < num_rep_0; m++) - load_a(m, 0, 0, true); - for(unsigned n = 0; n < num_rep_1; n+=2) + for(unsigned m = 0; m < num_rep_m; m++) + load_a(2*m, 0, 0, true); + for(unsigned n = 0; n < num_rep_n; n+=2) load_b(n, 0, 0, true); // update accumulators builder_->SetInsertPoint(CurrBB); - for(unsigned K = 0; K < NK; K += 16){ - int NEXTK = (K + 16) % NK; + for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2 + int next_k = (k + 1) % num_rep_k; // prefetch A - for(unsigned m = 0; m < num_rep_0; m++) - load_a(m, NEXTK, 1, true); + for(unsigned m = 0; m < num_rep_m; m++) + load_a(2*m, 2*next_k, 1, true); // prefetch B - for(unsigned n = 0; n < num_rep_1; n+=2) - load_b(n, NEXTK, 1, true); + for(unsigned n = 0; n < num_rep_n; n+=2) + load_b(n, 2*next_k, 1, true); // tensor core ops - for(unsigned m = 0; m < num_rep_0; m++) - for(unsigned n = 0; n < num_rep_1; n++){ - call_mma(m, n, K); + for(unsigned m = 0; m < num_rep_m; m++) + for(unsigned n = 0; n < num_rep_n; n++){ + call_mma(2*m, n, 2*k); } } } else{ - for(unsigned K = 0; K < NK; K += 16) - for(unsigned m = 0; m < num_rep_0; m++) - for(unsigned n = 0; n < num_rep_1; n++){ - if(ha.find({m, K}) == ha.end()) - load_a(m, K, 0, false); - if(hb.find({n, 
K})==hb.end()) - load_b(n, K, 0, false); - call_mma(m, n, K); - } + for (unsigned k = 0; k < num_rep_k; k++) { + for (unsigned m = 0; m < num_rep_m; m++) + load_a(2*m, 2*k, 0, /*is_prefetch*/false); + for (unsigned n = 0; n < num_rep_n; n+=2) + load_b(n, 2*k, 0, /*is_prefetch*/false); + for (unsigned m = 0; m < num_rep_m; m++) + for (unsigned n = 0; n < num_rep_n; n++) + call_mma(2*m, n, 2*k); + } } // write back unsigned i = 0; @@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80) return visit_mma884(dot, A, B, D, NK); if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) - return visit_mma16816(dot, A, B, D, NK); + return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()? return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); } @@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false); if(ty->getPrimitiveSizeInBits() <= 32) return call(shfl, {acc, i32(i)}); - acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2)); + acc = bit_cast(acc, vec_ty(f32_ty, 2)); Value* acc0 = builder_->CreateExtractElement(acc, i32(0)); Value* acc1 = builder_->CreateExtractElement(acc, i32(1)); Value* ret = UndefValue::get(vec_ty(f32_ty, 2)); ret = insert_elt(ret, shfl_sync(acc0, i), i32(0)); ret = insert_elt(ret, shfl_sync(acc1, i), i32(1)); - return builder_->CreateBitCast(ret, ty); + return bit_cast(ret, ty); } /** @@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes(); // pointer to temporary shared memory Type *ty = cvt(out->get_type()->get_scalar_ty()); + + if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store + ty = f16_ty; + // Orders analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); @@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ // stash value to shared mem indices_t idxs = {in_ax[0][i*max_ii + ii], in_ax[1][j*max_jj + jj]}; - store(vals_[in][idxs], ptr); + store(bit_cast(vals_[in][idxs], ty), ptr); } add_barrier(); max_ii = out_ax[0].size()/n_reps[0]; diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index 864fea85c..bae8fe828 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -80,7 +80,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ ir::value *a = dot->get_operand(0); ir::value *b = dot->get_operand(1); builder.set_insert_point(add); - ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name())); + ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name())); add->replace_all_uses_with(new_dot); return true; } diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc index c54a92005..f94d8f368 100644 --- a/lib/codegen/transform/prefetch.cc +++ b/lib/codegen/transform/prefetch.cc @@ -29,8 +29,13 @@ void prefetch::run(ir::module &mod) { std::vector to_prefetch; ir::for_each_instruction(mod, [&](ir::instruction *i) { if (auto *dot = dynamic_cast(i)) { - // Now only do prefetching when dot is fp16 - if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID) + // Now only do prefetching 
when dot is using tensor cores + if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() || + dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() || + (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32() + && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) + ) + ) return; auto *a = dynamic_cast(dot->get_operand(0)); auto *b = dynamic_cast(dot->get_operand(1)); diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 7248d6cec..910f23a1a 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -121,6 +121,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; + // module->print(llvm::outs(), nullptr); pm.add(llvm::createVerifierPass()); pm.run(*module); // create machine diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index a8ba68d1c..bf59b5b39 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -367,8 +367,8 @@ value *builder::create_log(value *arg){ return insert(log_inst::create(arg)); } -value *builder::create_dot(value *A, value *B, value *C) { - return insert(dot_inst::create_nn(A, B, C)); +value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) { + return insert(dot_inst::create_nn(A, B, C, allow_tf32)); } value *builder::create_trans(value *A, const std::vector& perm) { diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index a46c34b4f..12be53bfc 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -718,12 +718,13 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask // Linear Algebra //===----------------------------------------------------------------------===// -ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) { +ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) { ir::value *_0 = builder->get_float32(0); unsigned M = lhs->get_type()->get_block_shapes()[0]; unsigned N = rhs->get_type()->get_block_shapes()[1]; _0 = builder->create_splat(_0, {M, N}); - return builder->create_dot(lhs, rhs, _0); + bool _allow_tf32 = allow_tf32->get_value() != 0; + return builder->create_dot(lhs, rhs, _0, _allow_tf32); } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 0206b7e77..6e416a43e 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct // matmul_inst classes //===----------------------------------------------------------------------===// -dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, +dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next) : builtin_inst(C->get_type(), INST_DOT, 3, name, next) { set_operand(0, A); set_operand(1, B); set_operand(2, C); + allow_tf32_ = allow_tf32; } instruction *dot_inst::create(value *A, value *B, value *C, - bool AT, bool BT, + bool AT, bool BT, bool allow_tf32, const std::string &name, instruction *next) { TransT OPA = AT ? Trans : NoTrans; TransT OPB = BT ? 
Trans : NoTrans; - return new dot_inst(A, B, C, OPA, OPB, name, next); + return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next); } -instruction *dot_inst::create_nn(value *A, value *B, value *C, +instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name, instruction *next) { - return new dot_inst(A, B, C, NoTrans, NoTrans, name, next); + return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next); } -instruction *dot_inst::create_nt(value *A, value *B, value *C, +instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name, instruction *next) { - return new dot_inst(A, B, C, NoTrans, Trans, name, next); + return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next); } -instruction *dot_inst::create_tn(value *A, value *B, value *C, +instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name, instruction *next) { - return new dot_inst(A, B, C, Trans, NoTrans, name, next); + return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next); } -instruction *dot_inst::create_tt(value *A, value *B, value *C, +instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name, instruction *next) { - return new dot_inst(A, B, C, Trans, Trans, name, next); + return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next); } //===----------------------------------------------------------------------===// diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d8e88a609..e32622005 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -10,6 +10,7 @@ import torch from numpy.random import RandomState import triton +import triton._C.libtriton.triton as _triton import triton.language as tl from triton.code_gen import TensorWrapper, reinterpret @@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # --------------- -@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']) -def test_dot(epilogue, device='cuda'): +@pytest.mark.parametrize("epilogue, allow_tf32", + [(epilogue, allow_tf32) + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for allow_tf32 in [True, False]]) +def test_dot(epilogue, allow_tf32, device='cuda'): # triton kernel @triton.jit def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr): + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, + ALLOW_TF32: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) off_k = tl.arange(0, BLOCK_K) Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn - z = tl.dot(tl.load(Xs), tl.load(Ys)) + z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32) if ADD_MATRIX: z += tl.load(Zs) if ADD_ROWS: @@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'): rs = RandomState(17) x = numpy_random((M, K), dtype_str='float32', rs=rs) y = numpy_random((K, N), dtype_str='float32', rs=rs) + if allow_tf32: + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80: + pytest.skip("Only test tf32 on 
devices with sm >= 80") + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') x_tri = to_triton(x, device=device) y_tri = to_triton(y, device=device) # triton result @@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'): BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX=epilogue == 'add-matrix', ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols') + ADD_COLS=epilogue == 'add-cols', + ALLOW_TF32=allow_tf32) # torch result z_ref = np.matmul(x, y) if epilogue == 'add-matrix': @@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'): ptx = pgm.asm['ptx'] assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx + if allow_tf32: + assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx def test_dot_without_load(): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d32da45c3..4ef6408d0 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -524,18 +524,18 @@ def reshape(input, shape, _builder=None): @builtin -def dot(input, other, _builder=None): +def dot(input, other, allow_tf32=True, _builder=None): """ Returns the matrix product of two blocks. The two blocks must be two dimensionals and have compatible inner dimensions. :param input: The first block to be multiplied. - :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`} + :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} :param other: The second block to be multiplied. - :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`} + :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ - return frontend.dot(input, other, _builder) + return frontend.dot(input, other, allow_tf32, _builder) # ----------------------- From bf32205edc00ecffc557884ae4699fae4ce91e63 Mon Sep 17 00:00:00 2001 From: Botao Yu Date: Wed, 12 Jan 2022 03:07:16 +0800 Subject: [PATCH 046/215] [OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425) --- python/triton/ops/blocksparse/matmul.py | 6 +++--- python/triton/ops/blocksparse/softmax.py | 15 +++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 9a04ded66..48efe7ea3 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -283,14 +283,14 @@ def dsd_lut(layout, block, step, trans, device): # ------------------------------- # same as above, except that the increments are in the sparse memory layout if trans: - A_idx = torch.arange(num_blocks) + A_idx = torch.arange(num_blocks, device=layout.device) else: A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) current_offset = 0 for z in range(layout.size(0)): - layoutw = layout[z, :, :].clone() + layoutw = layout[z, :, :].clone().long() msum = layoutw.sum() - layoutw[layoutw > 0] = 1 + torch.arange(msum) + layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) current_offset += msum A_incs = A_idx * block * block diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index f9d49ae56..6ac76dcc4 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -106,19 +106,18 @@ def _backward(X, scale, DX, LUT, sizemax, stride_zx, 
stride_zdx, TN: tl.constexp class _softmax(torch.autograd.Function): @staticmethod def make_lut(layout, block, device): - _empty = torch.tensor([], dtype=torch.int64, device=layout.device) - sizes = _empty.clone() # sizes along rows - for h in range(layout.shape[0]): - sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + sizes = layout.sum(-1).view(-1) # offsets in block format offsets = torch.zeros_like(sizes) offsets[1:] = torch.cumsum(sizes[:-1], dim=0) # block indices - idx = torch.arange(layout.sum()) - head = layout.nonzero(as_tuple=False)[:, 0] - rows = layout.nonzero(as_tuple=False)[:, 1] - columns = layout.nonzero(as_tuple=False)[:, 2] + layout_sum = sizes.sum() + idx = torch.arange(layout_sum, device=layout.device) + layout_nonzero = layout.nonzero(as_tuple=False) + head = layout_nonzero[:, 0] + rows = layout_nonzero[:, 1] + columns = layout_nonzero[:, 2] core = torch.stack((idx, columns, rows, head), dim=1).view(-1) # construct look-up table offsets = offsets * 4 + 2 * sizes.numel() From bbc78f651600735bbcc98de3998b837dd0ce68c1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 11 Jan 2022 11:08:49 -0800 Subject: [PATCH 047/215] [FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427) --- python/triton/language/random.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 6f3645b41..69d7f4c4d 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -76,7 +76,7 @@ def uint32_to_uniform_float(x): """ Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - two_to_the_minus_32 = 2.328306e-10 + two_to_the_minus_32: tl.constexpr = 2.328306e-10 return x * two_to_the_minus_32 @@ -89,6 +89,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ + offset = offset.to(tl.uint32, bitcast=True) source = randint(seed, offset, n_rounds) return uint32_to_uniform_float(source) @@ -102,6 +103,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. 
""" + offsets = offsets.to(tl.uint32, bitcast=True) i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) u1 = uint32_to_uniform_float(i1) u2 = uint32_to_uniform_float(i2) From 4c94359199d02839cacc7674f61f7d0631ce45a9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 11 Jan 2022 23:11:58 -0800 Subject: [PATCH 048/215] [FRONTEND] Alignment fix-up (#428) --- include/triton/external/CUDA/cuda.h | 6209 +++++++++++++++++++++++---- python/src/triton.cc | 31 +- python/triton/code_gen.py | 14 +- python/triton/language/core.py | 3 +- 4 files changed, 5386 insertions(+), 871 deletions(-) diff --git a/include/triton/external/CUDA/cuda.h b/include/triton/external/CUDA/cuda.h index 24d96bd6c..f7bf9fc12 100755 --- a/include/triton/external/CUDA/cuda.h +++ b/include/triton/external/CUDA/cuda.h @@ -60,9 +60,6 @@ typedef uint32_t cuuint32_t; typedef uint64_t cuuint64_t; #endif -/** - * CUDA API versioning support - */ #if defined(__CUDA_API_VERSION_INTERNAL) || defined(__DOXYGEN_ONLY__) || defined(CUDA_ENABLE_DEPRECATED) #define __CUDA_DEPRECATED #elif defined(_MSC_VER) @@ -74,14 +71,8 @@ typedef uint64_t cuuint64_t; #endif #if defined(CUDA_FORCE_API_VERSION) - #if (CUDA_FORCE_API_VERSION == 3010) - #define __CUDA_API_VERSION 3010 - #else - #error "Unsupported value of CUDA_FORCE_API_VERSION" - #endif -#else - #define __CUDA_API_VERSION 10000 -#endif /* CUDA_FORCE_API_VERSION */ +#error "CUDA_FORCE_API_VERSION is no longer supported." +#endif #if defined(__CUDA_API_VERSION_INTERNAL) || defined(CUDA_API_PER_THREAD_DEFAULT_STREAM) #define __CUDA_API_PER_THREAD_DEFAULT_STREAM @@ -92,74 +83,66 @@ typedef uint64_t cuuint64_t; #define __CUDA_API_PTSZ(api) api #endif -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 3020 - #define cuDeviceTotalMem cuDeviceTotalMem_v2 - #define cuCtxCreate cuCtxCreate_v2 - #define cuModuleGetGlobal cuModuleGetGlobal_v2 - #define cuMemGetInfo cuMemGetInfo_v2 - #define cuMemAlloc cuMemAlloc_v2 - #define cuMemAllocPitch cuMemAllocPitch_v2 - #define cuMemFree cuMemFree_v2 - #define cuMemGetAddressRange cuMemGetAddressRange_v2 - #define cuMemAllocHost cuMemAllocHost_v2 - #define cuMemHostGetDevicePointer cuMemHostGetDevicePointer_v2 - #define cuMemcpyHtoD __CUDA_API_PTDS(cuMemcpyHtoD_v2) - #define cuMemcpyDtoH __CUDA_API_PTDS(cuMemcpyDtoH_v2) - #define cuMemcpyDtoD __CUDA_API_PTDS(cuMemcpyDtoD_v2) - #define cuMemcpyDtoA __CUDA_API_PTDS(cuMemcpyDtoA_v2) - #define cuMemcpyAtoD __CUDA_API_PTDS(cuMemcpyAtoD_v2) - #define cuMemcpyHtoA __CUDA_API_PTDS(cuMemcpyHtoA_v2) - #define cuMemcpyAtoH __CUDA_API_PTDS(cuMemcpyAtoH_v2) - #define cuMemcpyAtoA __CUDA_API_PTDS(cuMemcpyAtoA_v2) - #define cuMemcpyHtoAAsync __CUDA_API_PTSZ(cuMemcpyHtoAAsync_v2) - #define cuMemcpyAtoHAsync __CUDA_API_PTSZ(cuMemcpyAtoHAsync_v2) - #define cuMemcpy2D __CUDA_API_PTDS(cuMemcpy2D_v2) - #define cuMemcpy2DUnaligned __CUDA_API_PTDS(cuMemcpy2DUnaligned_v2) - #define cuMemcpy3D __CUDA_API_PTDS(cuMemcpy3D_v2) - #define cuMemcpyHtoDAsync __CUDA_API_PTSZ(cuMemcpyHtoDAsync_v2) - #define cuMemcpyDtoHAsync __CUDA_API_PTSZ(cuMemcpyDtoHAsync_v2) - #define cuMemcpyDtoDAsync __CUDA_API_PTSZ(cuMemcpyDtoDAsync_v2) - #define cuMemcpy2DAsync __CUDA_API_PTSZ(cuMemcpy2DAsync_v2) - #define cuMemcpy3DAsync __CUDA_API_PTSZ(cuMemcpy3DAsync_v2) - #define cuMemsetD8 __CUDA_API_PTDS(cuMemsetD8_v2) - #define cuMemsetD16 __CUDA_API_PTDS(cuMemsetD16_v2) - #define cuMemsetD32 __CUDA_API_PTDS(cuMemsetD32_v2) - #define cuMemsetD2D8 __CUDA_API_PTDS(cuMemsetD2D8_v2) - #define cuMemsetD2D16 __CUDA_API_PTDS(cuMemsetD2D16_v2) - 
#define cuMemsetD2D32 __CUDA_API_PTDS(cuMemsetD2D32_v2) - #define cuArrayCreate cuArrayCreate_v2 - #define cuArrayGetDescriptor cuArrayGetDescriptor_v2 - #define cuArray3DCreate cuArray3DCreate_v2 - #define cuArray3DGetDescriptor cuArray3DGetDescriptor_v2 - #define cuTexRefSetAddress cuTexRefSetAddress_v2 - #define cuTexRefGetAddress cuTexRefGetAddress_v2 - #define cuGraphicsResourceGetMappedPointer cuGraphicsResourceGetMappedPointer_v2 -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION >= 3020 */ -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 4000 - #define cuCtxDestroy cuCtxDestroy_v2 - #define cuCtxPopCurrent cuCtxPopCurrent_v2 - #define cuCtxPushCurrent cuCtxPushCurrent_v2 - #define cuStreamDestroy cuStreamDestroy_v2 - #define cuEventDestroy cuEventDestroy_v2 -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION >= 4000 */ -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 4010 - #define cuTexRefSetAddress2D cuTexRefSetAddress2D_v3 -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION >= 4010 */ -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 6050 - #define cuLinkCreate cuLinkCreate_v2 - #define cuLinkAddData cuLinkAddData_v2 - #define cuLinkAddFile cuLinkAddFile_v2 -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION >= 6050 */ -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION >= 6050 - #define cuMemHostRegister cuMemHostRegister_v2 - #define cuGraphicsResourceSetMapFlags cuGraphicsResourceSetMapFlags_v2 -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION >= 6050 */ - -#if !defined(__CUDA_API_VERSION_INTERNAL) -#if defined(__CUDA_API_VERSION) && __CUDA_API_VERSION >= 3020 && __CUDA_API_VERSION < 4010 - #define cuTexRefSetAddress2D cuTexRefSetAddress2D_v2 -#endif /* __CUDA_API_VERSION && __CUDA_API_VERSION >= 3020 && __CUDA_API_VERSION < 4010 */ -#endif /* __CUDA_API_VERSION_INTERNAL */ +#define cuDeviceTotalMem cuDeviceTotalMem_v2 +#define cuCtxCreate cuCtxCreate_v2 +#define cuCtxCreate_v3 cuCtxCreate_v3 +#define cuModuleGetGlobal cuModuleGetGlobal_v2 +#define cuMemGetInfo cuMemGetInfo_v2 +#define cuMemAlloc cuMemAlloc_v2 +#define cuMemAllocPitch cuMemAllocPitch_v2 +#define cuMemFree cuMemFree_v2 +#define cuMemGetAddressRange cuMemGetAddressRange_v2 +#define cuMemAllocHost cuMemAllocHost_v2 +#define cuMemHostGetDevicePointer cuMemHostGetDevicePointer_v2 +#define cuMemcpyHtoD __CUDA_API_PTDS(cuMemcpyHtoD_v2) +#define cuMemcpyDtoH __CUDA_API_PTDS(cuMemcpyDtoH_v2) +#define cuMemcpyDtoD __CUDA_API_PTDS(cuMemcpyDtoD_v2) +#define cuMemcpyDtoA __CUDA_API_PTDS(cuMemcpyDtoA_v2) +#define cuMemcpyAtoD __CUDA_API_PTDS(cuMemcpyAtoD_v2) +#define cuMemcpyHtoA __CUDA_API_PTDS(cuMemcpyHtoA_v2) +#define cuMemcpyAtoH __CUDA_API_PTDS(cuMemcpyAtoH_v2) +#define cuMemcpyAtoA __CUDA_API_PTDS(cuMemcpyAtoA_v2) +#define cuMemcpyHtoAAsync __CUDA_API_PTSZ(cuMemcpyHtoAAsync_v2) +#define cuMemcpyAtoHAsync __CUDA_API_PTSZ(cuMemcpyAtoHAsync_v2) +#define cuMemcpy2D __CUDA_API_PTDS(cuMemcpy2D_v2) +#define cuMemcpy2DUnaligned __CUDA_API_PTDS(cuMemcpy2DUnaligned_v2) +#define cuMemcpy3D __CUDA_API_PTDS(cuMemcpy3D_v2) +#define cuMemcpyHtoDAsync __CUDA_API_PTSZ(cuMemcpyHtoDAsync_v2) +#define cuMemcpyDtoHAsync __CUDA_API_PTSZ(cuMemcpyDtoHAsync_v2) +#define cuMemcpyDtoDAsync __CUDA_API_PTSZ(cuMemcpyDtoDAsync_v2) +#define cuMemcpy2DAsync __CUDA_API_PTSZ(cuMemcpy2DAsync_v2) +#define cuMemcpy3DAsync __CUDA_API_PTSZ(cuMemcpy3DAsync_v2) +#define cuMemsetD8 __CUDA_API_PTDS(cuMemsetD8_v2) +#define cuMemsetD16 
__CUDA_API_PTDS(cuMemsetD16_v2) +#define cuMemsetD32 __CUDA_API_PTDS(cuMemsetD32_v2) +#define cuMemsetD2D8 __CUDA_API_PTDS(cuMemsetD2D8_v2) +#define cuMemsetD2D16 __CUDA_API_PTDS(cuMemsetD2D16_v2) +#define cuMemsetD2D32 __CUDA_API_PTDS(cuMemsetD2D32_v2) +#define cuArrayCreate cuArrayCreate_v2 +#define cuArrayGetDescriptor cuArrayGetDescriptor_v2 +#define cuArray3DCreate cuArray3DCreate_v2 +#define cuArray3DGetDescriptor cuArray3DGetDescriptor_v2 +#define cuTexRefSetAddress cuTexRefSetAddress_v2 +#define cuTexRefGetAddress cuTexRefGetAddress_v2 +#define cuGraphicsResourceGetMappedPointer cuGraphicsResourceGetMappedPointer_v2 +#define cuCtxDestroy cuCtxDestroy_v2 +#define cuCtxPopCurrent cuCtxPopCurrent_v2 +#define cuCtxPushCurrent cuCtxPushCurrent_v2 +#define cuStreamDestroy cuStreamDestroy_v2 +#define cuEventDestroy cuEventDestroy_v2 +#define cuTexRefSetAddress2D cuTexRefSetAddress2D_v3 +#define cuLinkCreate cuLinkCreate_v2 +#define cuLinkAddData cuLinkAddData_v2 +#define cuLinkAddFile cuLinkAddFile_v2 +#define cuMemHostRegister cuMemHostRegister_v2 +#define cuGraphicsResourceSetMapFlags cuGraphicsResourceSetMapFlags_v2 +#define cuStreamBeginCapture __CUDA_API_PTSZ(cuStreamBeginCapture_v2) +#define cuDevicePrimaryCtxRelease cuDevicePrimaryCtxRelease_v2 +#define cuDevicePrimaryCtxReset cuDevicePrimaryCtxReset_v2 +#define cuDevicePrimaryCtxSetFlags cuDevicePrimaryCtxSetFlags_v2 +#define cuDeviceGetUuid_v2 cuDeviceGetUuid_v2 +#define cuIpcOpenMemHandle cuIpcOpenMemHandle_v2 +#define cuGraphInstantiate cuGraphInstantiate_v2 #if defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) #define cuMemcpy __CUDA_API_PTDS(cuMemcpy) @@ -181,14 +164,17 @@ typedef uint64_t cuuint64_t; #define cuStreamGetFlags __CUDA_API_PTSZ(cuStreamGetFlags) #define cuStreamGetCtx __CUDA_API_PTSZ(cuStreamGetCtx) #define cuStreamWaitEvent __CUDA_API_PTSZ(cuStreamWaitEvent) - #define cuStreamBeginCapture __CUDA_API_PTSZ(cuStreamBeginCapture) #define cuStreamEndCapture __CUDA_API_PTSZ(cuStreamEndCapture) #define cuStreamIsCapturing __CUDA_API_PTSZ(cuStreamIsCapturing) + #define cuStreamGetCaptureInfo __CUDA_API_PTSZ(cuStreamGetCaptureInfo) + #define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) + #define cuStreamUpdateCaptureDependencies __CUDA_API_PTSZ(cuStreamUpdateCaptureDependencies) #define cuStreamAddCallback __CUDA_API_PTSZ(cuStreamAddCallback) #define cuStreamAttachMemAsync __CUDA_API_PTSZ(cuStreamAttachMemAsync) #define cuStreamQuery __CUDA_API_PTSZ(cuStreamQuery) #define cuStreamSynchronize __CUDA_API_PTSZ(cuStreamSynchronize) #define cuEventRecord __CUDA_API_PTSZ(cuEventRecord) + #define cuEventRecordWithFlags __CUDA_API_PTSZ(cuEventRecordWithFlags) #define cuLaunchKernel __CUDA_API_PTSZ(cuLaunchKernel) #define cuLaunchHostFunc __CUDA_API_PTSZ(cuLaunchHostFunc) #define cuGraphicsMapResources __CUDA_API_PTSZ(cuGraphicsMapResources) @@ -205,7 +191,16 @@ typedef uint64_t cuuint64_t; #define cuSignalExternalSemaphoresAsync __CUDA_API_PTSZ(cuSignalExternalSemaphoresAsync) #define cuWaitExternalSemaphoresAsync __CUDA_API_PTSZ(cuWaitExternalSemaphoresAsync) + #define cuGraphUpload __CUDA_API_PTSZ(cuGraphUpload) #define cuGraphLaunch __CUDA_API_PTSZ(cuGraphLaunch) + #define cuStreamCopyAttributes __CUDA_API_PTSZ(cuStreamCopyAttributes) + #define cuStreamGetAttribute __CUDA_API_PTSZ(cuStreamGetAttribute) + #define cuStreamSetAttribute __CUDA_API_PTSZ(cuStreamSetAttribute) + #define cuMemMapArrayAsync __CUDA_API_PTSZ(cuMemMapArrayAsync) + + #define cuMemFreeAsync __CUDA_API_PTSZ(cuMemFreeAsync) + #define 
cuMemAllocAsync __CUDA_API_PTSZ(cuMemAllocAsync) + #define cuMemAllocFromPoolAsync __CUDA_API_PTSZ(cuMemAllocFromPoolAsync) #endif /** @@ -229,7 +224,7 @@ typedef uint64_t cuuint64_t; /** * CUDA API version number */ -#define CUDA_VERSION 10000 +#define CUDA_VERSION 11050 #ifdef __cplusplus extern "C" { @@ -239,34 +234,36 @@ extern "C" { * CUDA device pointer * CUdeviceptr is defined as an unsigned integer type whose size matches the size of a pointer on the target platform. */ -#if __CUDA_API_VERSION >= 3020 - #if defined(_WIN64) || defined(__LP64__) -typedef unsigned long long CUdeviceptr; +typedef unsigned long long CUdeviceptr_v2; #else -typedef unsigned int CUdeviceptr; +typedef unsigned int CUdeviceptr_v2; #endif +typedef CUdeviceptr_v2 CUdeviceptr; /**< CUDA device pointer */ -#endif /* __CUDA_API_VERSION >= 3020 */ - -typedef int CUdevice; /**< CUDA device */ -typedef struct CUctx_st *CUcontext; /**< CUDA context */ -typedef struct CUmod_st *CUmodule; /**< CUDA module */ -typedef struct CUfunc_st *CUfunction; /**< CUDA function */ -typedef struct CUarray_st *CUarray; /**< CUDA array */ -typedef struct CUmipmappedArray_st *CUmipmappedArray; /**< CUDA mipmapped array */ -typedef struct CUtexref_st *CUtexref; /**< CUDA texture reference */ -typedef struct CUsurfref_st *CUsurfref; /**< CUDA surface reference */ -typedef struct CUevent_st *CUevent; /**< CUDA event */ -typedef struct CUstream_st *CUstream; /**< CUDA stream */ -typedef struct CUgraphicsResource_st *CUgraphicsResource; /**< CUDA graphics interop resource */ -typedef unsigned long long CUtexObject; /**< An opaque value that represents a CUDA texture object */ -typedef unsigned long long CUsurfObject; /**< An opaque value that represents a CUDA surface object */ -typedef struct CUextMemory_st *CUexternalMemory; /**< CUDA external memory */ -typedef struct CUextSemaphore_st *CUexternalSemaphore; /**< CUDA external semaphore */ -typedef struct CUgraph_st *CUgraph; /**< CUDA graph */ -typedef struct CUgraphNode_st *CUgraphNode; /**< CUDA graph node */ -typedef struct CUgraphExec_st *CUgraphExec; /**< CUDA executable graph */ +typedef int CUdevice_v1; /**< CUDA device */ +typedef CUdevice_v1 CUdevice; /**< CUDA device */ +typedef struct CUctx_st *CUcontext; /**< CUDA context */ +typedef struct CUmod_st *CUmodule; /**< CUDA module */ +typedef struct CUfunc_st *CUfunction; /**< CUDA function */ +typedef struct CUarray_st *CUarray; /**< CUDA array */ +typedef struct CUmipmappedArray_st *CUmipmappedArray; /**< CUDA mipmapped array */ +typedef struct CUtexref_st *CUtexref; /**< CUDA texture reference */ +typedef struct CUsurfref_st *CUsurfref; /**< CUDA surface reference */ +typedef struct CUevent_st *CUevent; /**< CUDA event */ +typedef struct CUstream_st *CUstream; /**< CUDA stream */ +typedef struct CUgraphicsResource_st *CUgraphicsResource; /**< CUDA graphics interop resource */ +typedef unsigned long long CUtexObject_v1; /**< An opaque value that represents a CUDA texture object */ +typedef CUtexObject_v1 CUtexObject; /**< An opaque value that represents a CUDA texture object */ +typedef unsigned long long CUsurfObject_v1; /**< An opaque value that represents a CUDA surface object */ +typedef CUsurfObject_v1 CUsurfObject; /**< An opaque value that represents a CUDA surface object */ +typedef struct CUextMemory_st *CUexternalMemory; /**< CUDA external memory */ +typedef struct CUextSemaphore_st *CUexternalSemaphore; /**< CUDA external semaphore */ +typedef struct CUgraph_st *CUgraph; /**< CUDA graph */ +typedef struct 
CUgraphNode_st *CUgraphNode; /**< CUDA graph node */ +typedef struct CUgraphExec_st *CUgraphExec; /**< CUDA executable graph */ +typedef struct CUmemPoolHandle_st *CUmemoryPool; /**< CUDA memory pool */ +typedef struct CUuserObject_st *CUuserObject; /**< CUDA user object for graphs */ #ifndef CU_UUID_HAS_BEEN_DEFINED #define CU_UUID_HAS_BEEN_DEFINED @@ -275,8 +272,6 @@ typedef struct CUuuid_st { /**< CUDA definition o } CUuuid; #endif -#if __CUDA_API_VERSION >= 4010 - /** * CUDA IPC handle size */ @@ -287,14 +282,16 @@ typedef struct CUuuid_st { /**< CUDA definition o */ typedef struct CUipcEventHandle_st { char reserved[CU_IPC_HANDLE_SIZE]; -} CUipcEventHandle; +} CUipcEventHandle_v1; +typedef CUipcEventHandle_v1 CUipcEventHandle; /** * CUDA IPC mem handle */ typedef struct CUipcMemHandle_st { char reserved[CU_IPC_HANDLE_SIZE]; -} CUipcMemHandle; +} CUipcMemHandle_v1; +typedef CUipcMemHandle_v1 CUipcMemHandle; /** * CUDA Ipc Mem Flags @@ -303,7 +300,6 @@ typedef enum CUipcMem_flags_enum { CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS = 0x1 /**< Automatically enable peer access between remote devices as needed */ } CUipcMem_flags; -#endif /** * CUDA Mem Attach Flags @@ -326,7 +322,9 @@ typedef enum CUctx_flags_enum { * \deprecated This flag was deprecated as of CUDA 4.0 * and was replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. */ CU_CTX_SCHED_MASK = 0x07, - CU_CTX_MAP_HOST = 0x08, /**< Support mapped pinned allocations */ + CU_CTX_MAP_HOST = 0x08, /**< \deprecated This flag was deprecated as of CUDA 11.0 + * and it no longer has any effect. All contexts + * as of CUDA 3.2 behave as though the flag is enabled. */ CU_CTX_LMEM_RESIZE_TO_MAX = 0x10, /**< Keep local memory allocation after launch */ CU_CTX_FLAGS_MASK = 0x1f } CUctx_flags; @@ -335,8 +333,8 @@ typedef enum CUctx_flags_enum { * Stream creation flags */ typedef enum CUstream_flags_enum { - CU_STREAM_DEFAULT = 0x0, /**< Default stream flag */ - CU_STREAM_NON_BLOCKING = 0x1 /**< Stream does not synchronize with stream 0 (the NULL stream) */ + CU_STREAM_DEFAULT = 0x0, /**< Default stream flag */ + CU_STREAM_NON_BLOCKING = 0x1 /**< Stream does not synchronize with stream 0 (the NULL stream) */ } CUstream_flags; /** @@ -369,7 +367,26 @@ typedef enum CUevent_flags_enum { CU_EVENT_INTERPROCESS = 0x4 /**< Event is suitable for interprocess use. CU_EVENT_DISABLE_TIMING must be set */ } CUevent_flags; -#if __CUDA_API_VERSION >= 8000 +/** + * Event record flags + */ +typedef enum CUevent_record_flags_enum { + CU_EVENT_RECORD_DEFAULT = 0x0, /**< Default event record flag */ + CU_EVENT_RECORD_EXTERNAL = 0x1 /**< When using stream capture, create an event record node + * instead of the default behavior. This flag is invalid + * when used outside of capture. */ +} CUevent_record_flags; + +/** + * Event wait flags + */ +typedef enum CUevent_wait_flags_enum { + CU_EVENT_WAIT_DEFAULT = 0x0, /**< Default event wait flag */ + CU_EVENT_WAIT_EXTERNAL = 0x1 /**< When using stream capture, create an event wait node + * instead of the default behavior. 
This flag is invalid + * when used outside of capture.*/ +} CUevent_wait_flags; + /** * Flags for ::cuStreamWaitValue32 and ::cuStreamWaitValue64 */ @@ -448,8 +465,8 @@ typedef union CUstreamBatchMemOpParams_union { unsigned int flags; } flushRemoteWrites; cuuint64_t pad[6]; -} CUstreamBatchMemOpParams; -#endif /* __CUDA_API_VERSION >= 8000 */ +} CUstreamBatchMemOpParams_v1; +typedef CUstreamBatchMemOpParams_v1 CUstreamBatchMemOpParams; /** * Occupancy calculator flag @@ -459,6 +476,14 @@ typedef enum CUoccupancy_flags_enum { CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE = 0x1 /**< Assume global caching is enabled and cannot be automatically turned off */ } CUoccupancy_flags; +/** + * Flags for ::cuStreamUpdateCaptureDependencies + */ +typedef enum CUstreamUpdateCaptureDependencies_flags_enum { + CU_STREAM_ADD_CAPTURE_DEPENDENCIES = 0x0, /**< Add new nodes to the dependency set */ + CU_STREAM_SET_CAPTURE_DEPENDENCIES = 0x1 /**< Replace the dependency set with the new nodes */ +} CUstreamUpdateCaptureDependencies_flags; + /** * Array formats */ @@ -470,7 +495,34 @@ typedef enum CUarray_format_enum { CU_AD_FORMAT_SIGNED_INT16 = 0x09, /**< Signed 16-bit integers */ CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */ CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */ - CU_AD_FORMAT_FLOAT = 0x20 /**< 32-bit floating point */ + CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */ + CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */ + CU_AD_FORMAT_UNORM_INT8X1 = 0xc0, /**< 1 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT8X2 = 0xc1, /**< 2 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT8X4 = 0xc2, /**< 4 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X1 = 0xc3, /**< 1 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X2 = 0xc4, /**< 2 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X4 = 0xc5, /**< 4 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X1 = 0xc6, /**< 1 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X2 = 0xc7, /**< 2 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X4 = 0xc8, /**< 4 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X1 = 0xc9, /**< 1 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X2 = 0xca, /**< 2 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X4 = 0xcb, /**< 4 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format */ + CU_AD_FORMAT_BC1_UNORM_SRGB = 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format */ + CU_AD_FORMAT_BC2_UNORM_SRGB = 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format */ + CU_AD_FORMAT_BC3_UNORM_SRGB = 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized block-compressed (BC4 compression) format */ + CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed normalized block-compressed (BC4 compression) format */ + CU_AD_FORMAT_BC5_UNORM = 0x99, 
/**< 2 channel unsigned normalized block-compressed (BC5 compression) format */ + CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 channel signed normalized block-compressed (BC5 compression) format */ + CU_AD_FORMAT_BC6H_UF16 = 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H compression) format */ + CU_AD_FORMAT_BC6H_SF16 = 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) format */ + CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned normalized block-compressed (BC7 compression) format */ + CU_AD_FORMAT_BC7_UNORM_SRGB = 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) format with sRGB encoding */ } CUarray_format; /** @@ -495,112 +547,131 @@ typedef enum CUfilter_mode_enum { * Device properties */ typedef enum CUdevice_attribute_enum { - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 1, /**< Maximum number of threads per block */ - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2, /**< Maximum block dimension X */ - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y = 3, /**< Maximum block dimension Y */ - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z = 4, /**< Maximum block dimension Z */ - CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X = 5, /**< Maximum grid dimension X */ - CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y = 6, /**< Maximum grid dimension Y */ - CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z = 7, /**< Maximum grid dimension Z */ - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8, /**< Maximum shared memory available per block in bytes */ - CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK = 8, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK */ - CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY = 9, /**< Memory available on device for __constant__ variables in a CUDA C kernel in bytes */ - CU_DEVICE_ATTRIBUTE_WARP_SIZE = 10, /**< Warp size in threads */ - CU_DEVICE_ATTRIBUTE_MAX_PITCH = 11, /**< Maximum pitch in bytes allowed by memory copies */ - CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = 12, /**< Maximum number of 32-bit registers available per block */ - CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK = 12, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK */ - CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13, /**< Typical clock frequency in kilohertz */ - CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT = 14, /**< Alignment requirement for textures */ - CU_DEVICE_ATTRIBUTE_GPU_OVERLAP = 15, /**< Device can possibly copy memory and execute a kernel concurrently. Deprecated. Use instead CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT. 
*/ - CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16, /**< Number of multiprocessors on device */ - CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT = 17, /**< Specifies whether there is a run time limit on kernels */ - CU_DEVICE_ATTRIBUTE_INTEGRATED = 18, /**< Device is integrated with host memory */ - CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY = 19, /**< Device can map host memory into CUDA address space */ - CU_DEVICE_ATTRIBUTE_COMPUTE_MODE = 20, /**< Compute mode (See ::CUcomputemode for details) */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH = 21, /**< Maximum 1D texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH = 22, /**< Maximum 2D texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT = 23, /**< Maximum 2D texture height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH = 24, /**< Maximum 3D texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT = 25, /**< Maximum 3D texture height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH = 26, /**< Maximum 3D texture depth */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH = 27, /**< Maximum 2D layered texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT = 28, /**< Maximum 2D layered texture height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS = 29, /**< Maximum layers in a 2D layered texture */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH = 27, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT = 28, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES = 29, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS */ - CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT = 30, /**< Alignment requirement for surfaces */ - CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS = 31, /**< Device can possibly execute multiple kernels concurrently */ - CU_DEVICE_ATTRIBUTE_ECC_ENABLED = 32, /**< Device has ECC support enabled */ - CU_DEVICE_ATTRIBUTE_PCI_BUS_ID = 33, /**< PCI bus ID of the device */ - CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID = 34, /**< PCI device ID of the device */ - CU_DEVICE_ATTRIBUTE_TCC_DRIVER = 35, /**< Device is using TCC driver model */ - CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36, /**< Peak memory clock frequency in kilohertz */ - CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = 37, /**< Global memory bus width in bits */ - CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38, /**< Size of L2 cache in bytes */ - CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39, /**< Maximum resident threads per multiprocessor */ - CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT = 40, /**< Number of asynchronous engines */ - CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = 41, /**< Device shares a unified address space with the host */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH = 42, /**< Maximum 1D layered texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS = 43, /**< Maximum layers in a 1D layered texture */ - CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER = 44, /**< Deprecated, do not use. 
*/ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH = 45, /**< Maximum 2D texture width if CUDA_ARRAY3D_TEXTURE_GATHER is set */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT = 46, /**< Maximum 2D texture height if CUDA_ARRAY3D_TEXTURE_GATHER is set */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE = 47, /**< Alternate maximum 3D texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE = 48,/**< Alternate maximum 3D texture height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE = 49, /**< Alternate maximum 3D texture depth */ - CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID = 50, /**< PCI domain ID of the device */ - CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT = 51, /**< Pitch alignment requirement for textures */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH = 52, /**< Maximum cubemap texture width/height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH = 53, /**< Maximum cubemap layered texture width/height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS = 54, /**< Maximum layers in a cubemap layered texture */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH = 55, /**< Maximum 1D surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH = 56, /**< Maximum 2D surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT = 57, /**< Maximum 2D surface height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH = 58, /**< Maximum 3D surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT = 59, /**< Maximum 3D surface height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH = 60, /**< Maximum 3D surface depth */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH = 61, /**< Maximum 1D layered surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS = 62, /**< Maximum layers in a 1D layered surface */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH = 63, /**< Maximum 2D layered surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT = 64, /**< Maximum 2D layered surface height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS = 65, /**< Maximum layers in a 2D layered surface */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH = 66, /**< Maximum cubemap surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH = 67, /**< Maximum cubemap layered surface width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS = 68, /**< Maximum layers in a cubemap layered surface */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH = 69, /**< Maximum 1D linear texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH = 70, /**< Maximum 2D linear texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT = 71, /**< Maximum 2D linear texture height */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH = 72, /**< Maximum 2D linear texture pitch in bytes */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH = 73, /**< Maximum mipmapped 2D texture width */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT = 74,/**< Maximum mipmapped 2D texture height */ - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, /**< Major compute capability version number */ - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, /**< Minor compute capability version number */ - CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH = 77, /**< Maximum mipmapped 1D texture width */ - CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED = 78, /**< Device supports stream priorities */ - CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED = 79, /**< Device supports 
caching globals in L1 */ - CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED = 80, /**< Device supports caching locals in L1 */ - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR = 81, /**< Maximum shared memory available per multiprocessor in bytes */ - CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR = 82, /**< Maximum number of 32-bit registers available per multiprocessor */ - CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY = 83, /**< Device can allocate managed memory on this system */ - CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD = 84, /**< Device is on a multi-GPU board */ - CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID = 85, /**< Unique id for a group of devices on the same multi-GPU board */ - CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED = 86, /**< Link between the device and the host supports native atomic operations (this is a placeholder attribute, and is not supported on any current hardware)*/ - CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO = 87, /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ - CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS = 88, /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ - CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS = 89, /**< Device can coherently access managed memory concurrently with the CPU */ - CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED = 90, /**< Device supports compute preemption. */ - CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM = 91, /**< Device can access host registered memory at the same virtual address as the CPU */ - CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS = 92, /**< ::cuStreamBatchMemOp and related APIs are supported. */ - CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = 93, /**< 64-bit operations are supported in ::cuStreamBatchMemOp and related APIs. */ - CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = 94, /**< ::CU_STREAM_WAIT_VALUE_NOR is supported. */ - CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH = 95, /**< Device supports launching cooperative kernels via ::cuLaunchCooperativeKernel */ - CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH = 96, /**< Device can participate in cooperative kernels launched via ::cuLaunchCooperativeKernelMultiDevice */ - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97, /**< Maximum optin shared memory per block */ - CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES = 98, /**< Both the ::CU_STREAM_WAIT_VALUE_FLUSH flag and the ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on the device. See \ref CUDA_MEMOP for additional details. */ - CU_DEVICE_ATTRIBUTE_HOST_REGISTER_SUPPORTED = 99, /**< Device supports host memory registration via ::cudaHostRegister. 
*/ + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 1, /**< Maximum number of threads per block */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2, /**< Maximum block dimension X */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y = 3, /**< Maximum block dimension Y */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z = 4, /**< Maximum block dimension Z */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X = 5, /**< Maximum grid dimension X */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y = 6, /**< Maximum grid dimension Y */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z = 7, /**< Maximum grid dimension Z */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8, /**< Maximum shared memory available per block in bytes */ + CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK = 8, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK */ + CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY = 9, /**< Memory available on device for __constant__ variables in a CUDA C kernel in bytes */ + CU_DEVICE_ATTRIBUTE_WARP_SIZE = 10, /**< Warp size in threads */ + CU_DEVICE_ATTRIBUTE_MAX_PITCH = 11, /**< Maximum pitch in bytes allowed by memory copies */ + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = 12, /**< Maximum number of 32-bit registers available per block */ + CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK = 12, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK */ + CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13, /**< Typical clock frequency in kilohertz */ + CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT = 14, /**< Alignment requirement for textures */ + CU_DEVICE_ATTRIBUTE_GPU_OVERLAP = 15, /**< Device can possibly copy memory and execute a kernel concurrently. Deprecated. Use instead CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT. */ + CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16, /**< Number of multiprocessors on device */ + CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT = 17, /**< Specifies whether there is a run time limit on kernels */ + CU_DEVICE_ATTRIBUTE_INTEGRATED = 18, /**< Device is integrated with host memory */ + CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY = 19, /**< Device can map host memory into CUDA address space */ + CU_DEVICE_ATTRIBUTE_COMPUTE_MODE = 20, /**< Compute mode (See ::CUcomputemode for details) */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH = 21, /**< Maximum 1D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH = 22, /**< Maximum 2D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT = 23, /**< Maximum 2D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH = 24, /**< Maximum 3D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT = 25, /**< Maximum 3D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH = 26, /**< Maximum 3D texture depth */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH = 27, /**< Maximum 2D layered texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT = 28, /**< Maximum 2D layered texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS = 29, /**< Maximum layers in a 2D layered texture */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH = 27, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT = 28, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES = 29, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS */ + CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT = 30, /**< Alignment requirement for surfaces */ + CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS = 31, /**< Device can 
possibly execute multiple kernels concurrently */ + CU_DEVICE_ATTRIBUTE_ECC_ENABLED = 32, /**< Device has ECC support enabled */ + CU_DEVICE_ATTRIBUTE_PCI_BUS_ID = 33, /**< PCI bus ID of the device */ + CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID = 34, /**< PCI device ID of the device */ + CU_DEVICE_ATTRIBUTE_TCC_DRIVER = 35, /**< Device is using TCC driver model */ + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36, /**< Peak memory clock frequency in kilohertz */ + CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = 37, /**< Global memory bus width in bits */ + CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38, /**< Size of L2 cache in bytes */ + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39, /**< Maximum resident threads per multiprocessor */ + CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT = 40, /**< Number of asynchronous engines */ + CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = 41, /**< Device shares a unified address space with the host */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH = 42, /**< Maximum 1D layered texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS = 43, /**< Maximum layers in a 1D layered texture */ + CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER = 44, /**< Deprecated, do not use. */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH = 45, /**< Maximum 2D texture width if CUDA_ARRAY3D_TEXTURE_GATHER is set */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT = 46, /**< Maximum 2D texture height if CUDA_ARRAY3D_TEXTURE_GATHER is set */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE = 47, /**< Alternate maximum 3D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE = 48, /**< Alternate maximum 3D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE = 49, /**< Alternate maximum 3D texture depth */ + CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID = 50, /**< PCI domain ID of the device */ + CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT = 51, /**< Pitch alignment requirement for textures */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH = 52, /**< Maximum cubemap texture width/height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH = 53, /**< Maximum cubemap layered texture width/height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS = 54, /**< Maximum layers in a cubemap layered texture */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH = 55, /**< Maximum 1D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH = 56, /**< Maximum 2D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT = 57, /**< Maximum 2D surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH = 58, /**< Maximum 3D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT = 59, /**< Maximum 3D surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH = 60, /**< Maximum 3D surface depth */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH = 61, /**< Maximum 1D layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS = 62, /**< Maximum layers in a 1D layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH = 63, /**< Maximum 2D layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT = 64, /**< Maximum 2D layered surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS = 65, /**< Maximum layers in a 2D layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH = 66, /**< Maximum cubemap surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH = 67, /**< Maximum cubemap layered surface width */ + 
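/*
 * Minimal usage sketch for the CUdevice_attribute values in this enum (not
 * part of the header diff itself): every enumerator is queried the same way
 * with cuDeviceGetAttribute. Error checking is omitted for brevity.
 */
#include <cuda.h>
#include <stdio.h>

int main(void) {
    CUdevice dev;
    int sm_count = 0;
    cuInit(0);
    cuDeviceGet(&dev, 0);
    /* CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT is enumerator 16 above. */
    cuDeviceGetAttribute(&sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev);
    printf("multiprocessors: %d\n", sm_count);
    return 0;
}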
CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS = 68, /**< Maximum layers in a cubemap layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH = 69, /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH = 70, /**< Maximum 2D linear texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT = 71, /**< Maximum 2D linear texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH = 72, /**< Maximum 2D linear texture pitch in bytes */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH = 73, /**< Maximum mipmapped 2D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT = 74, /**< Maximum mipmapped 2D texture height */ + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, /**< Major compute capability version number */ + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, /**< Minor compute capability version number */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH = 77, /**< Maximum mipmapped 1D texture width */ + CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED = 78, /**< Device supports stream priorities */ + CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED = 79, /**< Device supports caching globals in L1 */ + CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED = 80, /**< Device supports caching locals in L1 */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR = 81, /**< Maximum shared memory available per multiprocessor in bytes */ + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR = 82, /**< Maximum number of 32-bit registers available per multiprocessor */ + CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY = 83, /**< Device can allocate managed memory on this system */ + CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD = 84, /**< Device is on a multi-GPU board */ + CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID = 85, /**< Unique id for a group of devices on the same multi-GPU board */ + CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED = 86, /**< Link between the device and the host supports native atomic operations (this is a placeholder attribute, and is not supported on any current hardware)*/ + CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO = 87, /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ + CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS = 88, /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ + CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS = 89, /**< Device can coherently access managed memory concurrently with the CPU */ + CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED = 90, /**< Device supports compute preemption. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM = 91, /**< Device can access host registered memory at the same virtual address as the CPU */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS = 92, /**< ::cuStreamBatchMemOp and related APIs are supported. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = 93, /**< 64-bit operations are supported in ::cuStreamBatchMemOp and related APIs. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = 94, /**< ::CU_STREAM_WAIT_VALUE_NOR is supported. 
*/ + CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH = 95, /**< Device supports launching cooperative kernels via ::cuLaunchCooperativeKernel */ + CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH = 96, /**< Deprecated, ::cuLaunchCooperativeKernelMultiDevice is deprecated. */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97, /**< Maximum optin shared memory per block */ + CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES = 98, /**< The ::CU_STREAM_WAIT_VALUE_FLUSH flag and the ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on the device. See \ref CUDA_MEMOP for additional details. */ + CU_DEVICE_ATTRIBUTE_HOST_REGISTER_SUPPORTED = 99, /**< Device supports host memory registration via ::cudaHostRegister. */ CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES = 100, /**< Device accesses pageable memory via the host's page tables. */ - CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST = 101, /**< The host can directly access managed memory on the device without migration. */ + CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST = 101, /**< The host can directly access managed memory on the device without migration. */ + CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED = 102, /**< Deprecated, Use CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED*/ + CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */ + CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */ + CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes. */ + CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE = 109, /**< Maximum value of CUaccessPolicyWindow::num_bytes. 
*/ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */ + CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */ + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */ + CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */ + CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS = 117, /**< The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING = 118, /**< GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. */ + CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES = 119, /**< Handle types supported with mempool based IPC */ CU_DEVICE_ATTRIBUTE_MAX } CUdevice_attribute; @@ -618,21 +689,30 @@ typedef struct CUdevprop_st { int regsPerBlock; /**< 32-bit registers available per block */ int clockRate; /**< Clock frequency in kilohertz */ int textureAlign; /**< Alignment requirement for textures */ -} CUdevprop; +} CUdevprop_v1; +typedef CUdevprop_v1 CUdevprop; /** * Pointer information */ typedef enum CUpointer_attribute_enum { - CU_POINTER_ATTRIBUTE_CONTEXT = 1, /**< The ::CUcontext on which a pointer was allocated or registered */ - CU_POINTER_ATTRIBUTE_MEMORY_TYPE = 2, /**< The ::CUmemorytype describing the physical location of a pointer */ - CU_POINTER_ATTRIBUTE_DEVICE_POINTER = 3, /**< The address at which a pointer's memory may be accessed on the device */ - CU_POINTER_ATTRIBUTE_HOST_POINTER = 4, /**< The address at which a pointer's memory may be accessed on the host */ - CU_POINTER_ATTRIBUTE_P2P_TOKENS = 5, /**< A pair of tokens for use with the nv-p2p.h Linux kernel interface */ - CU_POINTER_ATTRIBUTE_SYNC_MEMOPS = 6, /**< Synchronize every synchronous memory operation initiated on this region */ - CU_POINTER_ATTRIBUTE_BUFFER_ID = 7, /**< A process-wide unique ID for an allocated memory region*/ - CU_POINTER_ATTRIBUTE_IS_MANAGED = 8, /**< Indicates if the pointer points to managed memory */ - CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = 9 /**< A device ordinal of a device on which a pointer was allocated or registered */ + CU_POINTER_ATTRIBUTE_CONTEXT = 1, /**< The ::CUcontext on which a pointer was allocated or registered */ + CU_POINTER_ATTRIBUTE_MEMORY_TYPE = 2, /**< The ::CUmemorytype describing the physical location of a pointer */ + CU_POINTER_ATTRIBUTE_DEVICE_POINTER = 3, /**< The address at which a pointer's memory may be accessed on the device */ + CU_POINTER_ATTRIBUTE_HOST_POINTER = 4, /**< The address at which a pointer's memory 
may be accessed on the host */ + CU_POINTER_ATTRIBUTE_P2P_TOKENS = 5, /**< A pair of tokens for use with the nv-p2p.h Linux kernel interface */ + CU_POINTER_ATTRIBUTE_SYNC_MEMOPS = 6, /**< Synchronize every synchronous memory operation initiated on this region */ + CU_POINTER_ATTRIBUTE_BUFFER_ID = 7, /**< A process-wide unique ID for an allocated memory region*/ + CU_POINTER_ATTRIBUTE_IS_MANAGED = 8, /**< Indicates if the pointer points to managed memory */ + CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = 9, /**< A device ordinal of a device on which a pointer was allocated or registered */ + CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE = 10, /**< 1 if this pointer maps to an allocation that is suitable for ::cudaIpcGetMemHandle, 0 otherwise **/ + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR = 11, /**< Starting address for this requested pointer */ + CU_POINTER_ATTRIBUTE_RANGE_SIZE = 12, /**< Size of the address range for this requested pointer */ + CU_POINTER_ATTRIBUTE_MAPPED = 13, /**< 1 if this pointer is in a valid address range that is mapped to a backing allocation, 0 otherwise **/ + CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES = 14, /**< Bitmask of allowed ::CUmemAllocationHandleType for this allocation **/ + CU_POINTER_ATTRIBUTE_IS_GPU_DIRECT_RDMA_CAPABLE = 15, /**< 1 if the memory this pointer is referencing can be used with the GPUDirect RDMA API **/ + CU_POINTER_ATTRIBUTE_ACCESS_FLAGS = 16, /**< Returns the access flags the device associated with the current context has on the corresponding memory referenced by the pointer given */ + CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE = 17 /**< Returns the mempool handle for the allocation if it was allocated from a mempool. Otherwise returns NULL. **/ } CUpointer_attribute; /** @@ -697,13 +777,16 @@ typedef enum CUfunction_attribute_enum { * The maximum size in bytes of dynamically-allocated shared memory that can be used by * this function. If the user-specified dynamic shared memory size is larger than this * value, the launch will fail. + * See ::cuFuncSetAttribute */ CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8, /** - * On devices where the L1 cache and shared memory use the same hardware resources, - * this sets the shared memory carveout preference, in percent of the total resources. + * On devices where the L1 cache and shared memory use the same hardware resources, + * this sets the shared memory carveout preference, in percent of the total shared memory. + * Refer to ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR. * This is only a hint, and the driver can choose a different ratio if required to execute the function. + * See ::cuFuncSetAttribute */ CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT = 9, @@ -730,12 +813,12 @@ typedef enum CUsharedconfig_enum { } CUsharedconfig; /** - * Shared memory carveout configurations + * Shared memory carveout configurations. 
These may be passed to ::cuFuncSetAttribute */ typedef enum CUshared_carveout_enum { - CU_SHAREDMEM_CARVEOUT_DEFAULT = -1, /** < no preference for shared memory or L1 (default) */ - CU_SHAREDMEM_CARVEOUT_MAX_SHARED = 100, /** < prefer maximum available shared memory, minimum L1 cache */ - CU_SHAREDMEM_CARVEOUT_MAX_L1 = 0 /** < prefer maximum available L1 cache, minimum shared memory */ + CU_SHAREDMEM_CARVEOUT_DEFAULT = -1, /**< No preference for shared memory or L1 (default) */ + CU_SHAREDMEM_CARVEOUT_MAX_SHARED = 100, /**< Prefer maximum available shared memory, minimum L1 cache */ + CU_SHAREDMEM_CARVEOUT_MAX_L1 = 0 /**< Prefer maximum available L1 cache, minimum shared memory */ } CUshared_carveout; /** @@ -947,6 +1030,51 @@ typedef enum CUjit_option_enum */ CU_JIT_GLOBAL_SYMBOL_COUNT, + /** + * Enable link-time optimization (-dlto) for device code (0: false, default)\n + * Option type: int\n + * Applies to: compiler and linker + */ + CU_JIT_LTO, + + /** + * Control single-precision denormals (-ftz) support (0: false, default). + * 1 : flushes denormal values to zero + * 0 : preserves denormal values + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_FTZ, + + /** + * Control single-precision floating-point division and reciprocals + * (-prec-div) support (1: true, default). + * 1 : Enables the IEEE round-to-nearest mode + * 0 : Enables the fast approximation mode + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_PREC_DIV, + + /** + * Control single-precision floating-point square root + * (-prec-sqrt) support (1: true, default). + * 1 : Enables the IEEE round-to-nearest mode + * 0 : Enables the fast approximation mode + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_PREC_SQRT, + + /** + * Enable/Disable the contraction of floating-point multiplies + * and adds/subtracts into floating-point multiply-add (-fma) + * operations (1: Enable, default; 0: Disable). 
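/*
 * Minimal usage sketch (not part of the header diff): the two function
 * attributes and the carveout enumerators above are applied per-kernel with
 * cuFuncSetAttribute. 'func' is assumed to come from cuModuleGetFunction.
 */
static CUresult opt_in_shared_memory(CUfunction func) {
    /* Allow up to 96 KiB of dynamic shared memory for this kernel. */
    CUresult err = cuFuncSetAttribute(
        func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 96 * 1024);
    if (err != CUDA_SUCCESS)
        return err;
    /* Hint the driver to favor shared memory over L1 for this kernel. */
    return cuFuncSetAttribute(
        func, CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT,
        CU_SHAREDMEM_CARVEOUT_MAX_SHARED);
}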
+ * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_FMA, + CU_JIT_NUM_OPTIONS } CUjit_option; @@ -969,8 +1097,10 @@ typedef enum CUjit_target_enum CU_TARGET_COMPUTE_61 = 61, /**< Compute device class 6.1.*/ CU_TARGET_COMPUTE_62 = 62, /**< Compute device class 6.2.*/ CU_TARGET_COMPUTE_70 = 70, /**< Compute device class 7.0.*/ - - CU_TARGET_COMPUTE_75 = 75 /**< Compute device class 7.5.*/ + CU_TARGET_COMPUTE_72 = 72, /**< Compute device class 7.2.*/ + CU_TARGET_COMPUTE_75 = 75, /**< Compute device class 7.5.*/ + CU_TARGET_COMPUTE_80 = 80, /**< Compute device class 8.0.*/ + CU_TARGET_COMPUTE_86 = 86 /**< Compute device class 8.6.*/ } CUjit_target; /** @@ -1029,12 +1159,16 @@ typedef enum CUjitInputType_enum */ CU_JIT_INPUT_LIBRARY, + /** + * High-level intermediate code for link-time optimization\n + * Applicable options: NVVM compiler options, PTX compiler options + */ + CU_JIT_INPUT_NVVM, + CU_JIT_NUM_INPUT_TYPES } CUjitInputType; -#if __CUDA_API_VERSION >= 5050 typedef struct CUlinkState_st *CUlinkState; -#endif /* __CUDA_API_VERSION >= 5050 */ /** * Flags to register a graphics resource @@ -1078,6 +1212,7 @@ typedef enum CUlimit_enum { CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH = 0x03, /**< GPU device runtime launch synchronize depth */ CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT = 0x04, /**< GPU device runtime pending launch count */ CU_LIMIT_MAX_L2_FETCH_GRANULARITY = 0x05, /**< A value between 0 and 128 that indicates the maximum fetch granularity of L2 (in Bytes). This is a hint */ + CU_LIMIT_PERSISTING_L2_CACHE_SIZE = 0x06, /**< A size in bytes for L2 persisting lines cache size */ CU_LIMIT_MAX } CUlimit; @@ -1097,14 +1232,42 @@ typedef enum CUresourcetype_enum { #define CUDA_CB #endif -#if __CUDA_API_VERSION >= 10000 - /** * CUDA host function * \param userData Argument value passed to the function */ typedef void (CUDA_CB *CUhostFn)(void *userData); +/** + * Specifies performance hint with ::CUaccessPolicyWindow for hitProp and missProp members. + */ +typedef enum CUaccessProperty_enum { + CU_ACCESS_PROPERTY_NORMAL = 0, /**< Normal cache persistence. */ + CU_ACCESS_PROPERTY_STREAMING = 1, /**< Streaming access is less likely to persit from cache. */ + CU_ACCESS_PROPERTY_PERSISTING = 2 /**< Persisting access is more likely to persist in cache.*/ +} CUaccessProperty; + +/** + * Specifies an access policy for a window, a contiguous extent of memory + * beginning at base_ptr and ending at base_ptr + num_bytes. + * num_bytes is limited by CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE. + * Partition into many segments and assign segments such that: + * sum of "hit segments" / window == approx. ratio. + * sum of "miss segments" / window == approx 1-ratio. + * Segments and ratio specifications are fitted to the capabilities of + * the architecture. + * Accesses in a hit segment apply the hitProp access policy. + * Accesses in a miss segment apply the missProp access policy. + */ +typedef struct CUaccessPolicyWindow_st { + void *base_ptr; /**< Starting address of the access policy window. CUDA driver may align it. */ + size_t num_bytes; /**< Size in bytes of the window policy. CUDA driver may restrict the maximum size and alignment. */ + float hitRatio; /**< hitRatio specifies percentage of lines assigned hitProp, rest are assigned missProp. */ + CUaccessProperty hitProp; /**< ::CUaccessProperty set for hit. */ + CUaccessProperty missProp; /**< ::CUaccessProperty set for miss. 
Must be either NORMAL or STREAMING */ +} CUaccessPolicyWindow_v1; +typedef CUaccessPolicyWindow_v1 CUaccessPolicyWindow; + /** * GPU kernel node parameters */ @@ -1119,7 +1282,8 @@ typedef struct CUDA_KERNEL_NODE_PARAMS_st { unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block in bytes */ void **kernelParams; /**< Array of pointers to kernel parameters */ void **extra; /**< Extra options */ -} CUDA_KERNEL_NODE_PARAMS; +} CUDA_KERNEL_NODE_PARAMS_v1; +typedef CUDA_KERNEL_NODE_PARAMS_v1 CUDA_KERNEL_NODE_PARAMS; /** * Memset node parameters @@ -1129,9 +1293,10 @@ typedef struct CUDA_MEMSET_NODE_PARAMS_st { size_t pitch; /**< Pitch of destination device pointer. Unused if height is 1 */ unsigned int value; /**< Value to be set */ unsigned int elementSize; /**< Size of each element in bytes. Must be 1, 2, or 4. */ - size_t width; /**< Width in bytes, of the row */ + size_t width; /**< Width of the row in elements */ size_t height; /**< Number of rows */ -} CUDA_MEMSET_NODE_PARAMS; +} CUDA_MEMSET_NODE_PARAMS_v1; +typedef CUDA_MEMSET_NODE_PARAMS_v1 CUDA_MEMSET_NODE_PARAMS; /** * Host node parameters @@ -1139,21 +1304,51 @@ typedef struct CUDA_MEMSET_NODE_PARAMS_st { typedef struct CUDA_HOST_NODE_PARAMS_st { CUhostFn fn; /**< The function to call when the node executes */ void* userData; /**< Argument to pass to the function */ -} CUDA_HOST_NODE_PARAMS; +} CUDA_HOST_NODE_PARAMS_v1; +typedef CUDA_HOST_NODE_PARAMS_v1 CUDA_HOST_NODE_PARAMS; /** * Graph node types */ typedef enum CUgraphNodeType_enum { - CU_GRAPH_NODE_TYPE_KERNEL = 0, /**< GPU kernel node */ - CU_GRAPH_NODE_TYPE_MEMCPY = 1, /**< Memcpy node */ - CU_GRAPH_NODE_TYPE_MEMSET = 2, /**< Memset node */ - CU_GRAPH_NODE_TYPE_HOST = 3, /**< Host (executable) node */ - CU_GRAPH_NODE_TYPE_GRAPH = 4, /**< Node which executes an embedded graph */ - CU_GRAPH_NODE_TYPE_EMPTY = 5, /**< Empty (no-op) node */ - CU_GRAPH_NODE_TYPE_COUNT + CU_GRAPH_NODE_TYPE_KERNEL = 0, /**< GPU kernel node */ + CU_GRAPH_NODE_TYPE_MEMCPY = 1, /**< Memcpy node */ + CU_GRAPH_NODE_TYPE_MEMSET = 2, /**< Memset node */ + CU_GRAPH_NODE_TYPE_HOST = 3, /**< Host (executable) node */ + CU_GRAPH_NODE_TYPE_GRAPH = 4, /**< Node which executes an embedded graph */ + CU_GRAPH_NODE_TYPE_EMPTY = 5, /**< Empty (no-op) node */ + CU_GRAPH_NODE_TYPE_WAIT_EVENT = 6, /**< External event wait node */ + CU_GRAPH_NODE_TYPE_EVENT_RECORD = 7, /**< External event record node */ + CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL = 8, /**< External semaphore signal node */ + CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT = 9, /**< External semaphore wait node */ + CU_GRAPH_NODE_TYPE_MEM_ALLOC = 10,/**< Memory Allocation Node */ + CU_GRAPH_NODE_TYPE_MEM_FREE = 11 /**< Memory Free Node */ } CUgraphNodeType; +typedef enum CUsynchronizationPolicy_enum { + CU_SYNC_POLICY_AUTO = 1, + CU_SYNC_POLICY_SPIN = 2, + CU_SYNC_POLICY_YIELD = 3, + CU_SYNC_POLICY_BLOCKING_SYNC = 4 +} CUsynchronizationPolicy; + +/** + * Graph kernel node Attributes + */ +typedef enum CUkernelNodeAttrID_enum { + CU_KERNEL_NODE_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, /**< Identifier for ::CUkernelNodeAttrValue::accessPolicyWindow. */ + CU_KERNEL_NODE_ATTRIBUTE_COOPERATIVE = 2 /**< Allows a kernel node to be cooperative (see ::cuLaunchCooperativeKernel). */ +} CUkernelNodeAttrID; + +/** + * Graph kernel node attributes union, used with ::cuKernelNodeSetAttribute/::cuKernelNodeGetAttribute + */ +typedef union CUkernelNodeAttrValue_union { + CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. 
*/ + int cooperative; /**< Nonzero indicates a cooperative kernel (see ::cuLaunchCooperativeKernel). */ +} CUkernelNodeAttrValue_v1; +typedef CUkernelNodeAttrValue_v1 CUkernelNodeAttrValue; + /** * Possible stream capture statuses returned by ::cuStreamIsCapturing */ @@ -1164,7 +1359,68 @@ typedef enum CUstreamCaptureStatus_enum { has been invalidated, but not terminated */ } CUstreamCaptureStatus; -#endif /* __CUDA_API_VERSION >= 10000 */ +/** + * Possible modes for stream capture thread interactions. For more details see + * ::cuStreamBeginCapture and ::cuThreadExchangeStreamCaptureMode + */ +typedef enum CUstreamCaptureMode_enum { + CU_STREAM_CAPTURE_MODE_GLOBAL = 0, + CU_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, + CU_STREAM_CAPTURE_MODE_RELAXED = 2 +} CUstreamCaptureMode; + +/** + * Stream Attributes + */ +typedef enum CUstreamAttrID_enum { + CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, /**< Identifier for ::CUstreamAttrValue::accessPolicyWindow. */ + CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY = 3 /**< ::CUsynchronizationPolicy for work queued up in this stream */ +} CUstreamAttrID; + +/** + * Stream attributes union, used with ::cuStreamSetAttribute/::cuStreamGetAttribute + */ +typedef union CUstreamAttrValue_union { + CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. */ + CUsynchronizationPolicy syncPolicy; /**< Value for ::CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY. */ +} CUstreamAttrValue_v1; +typedef CUstreamAttrValue_v1 CUstreamAttrValue; + +/** + * Flags to specify search options. For more details see ::cuGetProcAddress + */ +typedef enum CUdriverProcAddress_flags_enum { + CU_GET_PROC_ADDRESS_DEFAULT = 0, /**< Default search mode for driver symbols. */ + CU_GET_PROC_ADDRESS_LEGACY_STREAM = 1 << 0, /**< Search for legacy versions of driver symbols. */ + CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM = 1 << 1 /**< Search for per-thread versions of driver symbols. */ +} CUdriverProcAddress_flags; + +/** + * Execution Affinity Types + */ +typedef enum CUexecAffinityType_enum { + CU_EXEC_AFFINITY_TYPE_SM_COUNT = 0, /**< Create a context with limited SMs. */ + CU_EXEC_AFFINITY_TYPE_MAX +} CUexecAffinityType; + +/** + * Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT + */ +typedef struct CUexecAffinitySmCount_st { + unsigned int val; /**< The number of SMs the context is limited to use. */ +} CUexecAffinitySmCount_v1; +typedef CUexecAffinitySmCount_v1 CUexecAffinitySmCount; + +/** + * Execution Affinity Parameters + */ +typedef struct CUexecAffinityParam_st { + CUexecAffinityType type; + union { + CUexecAffinitySmCount smCount; /** Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT */ + } param; +} CUexecAffinityParam_v1; +typedef CUexecAffinityParam_v1 CUexecAffinityParam; /** * Error codes @@ -1229,6 +1485,13 @@ typedef enum cudaError_enum { */ CUDA_ERROR_PROFILER_ALREADY_STOPPED = 8, + /** + * This indicates that the CUDA driver that the application has loaded is a + * stub library. Applications that run with the stub rather than a real + * driver loaded will result in CUDA API returning this error. + */ + CUDA_ERROR_STUB_LIBRARY = 34, + /** * This indicates that no CUDA-capable devices were detected by the installed * CUDA driver. @@ -1237,10 +1500,15 @@ typedef enum cudaError_enum { /** * This indicates that the device ordinal supplied by the user does not - * correspond to a valid CUDA device. + * correspond to a valid CUDA device or that the action requested is + * invalid for the specified device. 
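/*
 * Minimal usage sketch (not part of the header diff): attaching an access
 * policy window to a stream through the CUstreamAttrValue union above, so a
 * frequently reused buffer tends to persist in L2. 'stream', 'ptr' and
 * 'nbytes' are assumed to already exist; cuStreamSetAttribute is the
 * companion driver call named in the union's comment.
 */
#include <cuda.h>

static void set_persisting_window(CUstream stream, CUdeviceptr ptr, size_t nbytes) {
    CUstreamAttrValue attr;
    attr.accessPolicyWindow.base_ptr  = (void *)ptr;
    attr.accessPolicyWindow.num_bytes = nbytes;  /* bounded by CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE */
    attr.accessPolicyWindow.hitRatio  = 0.6f;    /* ~60% of lines get hitProp, the rest missProp */
    attr.accessPolicyWindow.hitProp   = CU_ACCESS_PROPERTY_PERSISTING;
    attr.accessPolicyWindow.missProp  = CU_ACCESS_PROPERTY_STREAMING;
    cuStreamSetAttribute(stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, &attr);
}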
*/ CUDA_ERROR_INVALID_DEVICE = 101, + /** + * This error indicates that the Grid license is not applied. + */ + CUDA_ERROR_DEVICE_NOT_LICENSED = 102, /** * This indicates that the device kernel image is invalid. This can also @@ -1365,7 +1633,25 @@ typedef enum cudaError_enum { CUDA_ERROR_JIT_COMPILER_NOT_FOUND = 221, /** - * This indicates that the device kernel source is invalid. + * This indicates that the provided PTX was compiled with an unsupported toolchain. + */ + + CUDA_ERROR_UNSUPPORTED_PTX_VERSION = 222, + + /** + * This indicates that the PTX JIT compilation was disabled. + */ + CUDA_ERROR_JIT_COMPILATION_DISABLED = 223, + + /** + * This indicates that the ::CUexecAffinityType passed to the API call is not + * supported by the active device. + */ + CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224, + + /** + * This indicates that the device kernel source is invalid. This includes + * compilation/linker errors encountered in device code or user error. */ CUDA_ERROR_INVALID_SOURCE = 300, @@ -1403,7 +1689,8 @@ typedef enum cudaError_enum { /** * This indicates that a named symbol was not found. Examples of symbols - * are global/constant variable names, texture names, and surface names. + * are global/constant variable names, driver function names, texture names, + * and surface names. */ CUDA_ERROR_NOT_FOUND = 500, @@ -1553,7 +1840,8 @@ typedef enum cudaError_enum { /** * An exception occurred on the device while executing a kernel. Common * causes include dereferencing an invalid device pointer and accessing - * out of bounds shared memory. + * out of bounds shared memory. Less common cases can be system specific - more + * information about these cases can be found in the system specific user guide. * This leaves the process in an inconsistent state and any further CUDA work * will return the same error. To continue using CUDA, the process must be terminated * and relaunched. @@ -1584,9 +1872,53 @@ typedef enum cudaError_enum { * This error indicates that the system is not yet ready to start any CUDA * work. To continue using CUDA, verify the system configuration is in a * valid state and all required driver daemons are actively running. + * More information about this error can be found in the system specific + * user guide. */ CUDA_ERROR_SYSTEM_NOT_READY = 802, + /** + * This error indicates that there is a mismatch between the versions of + * the display driver and the CUDA driver. Refer to the compatibility documentation + * for supported versions. + */ + CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803, + + /** + * This error indicates that the system was upgraded to run with forward compatibility + * but the visible hardware detected by CUDA does not support this configuration. + * Refer to the compatibility documentation for the supported hardware matrix or ensure + * that only supported hardware is visible during initialization via the CUDA_VISIBLE_DEVICES + * environment variable. + */ + CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE = 804, + + /** + * This error indicates that the MPS client failed to connect to the MPS control daemon or the MPS server. + */ + CUDA_ERROR_MPS_CONNECTION_FAILED = 805, + + /** + * This error indicates that the remote procedural call between the MPS server and the MPS client failed. + */ + CUDA_ERROR_MPS_RPC_FAILURE = 806, + + /** + * This error indicates that the MPS server is not ready to accept new MPS client requests. + * This error can be returned when the MPS server is in the process of recovering from a fatal failure. 
+ */ + CUDA_ERROR_MPS_SERVER_NOT_READY = 807, + + /** + * This error indicates that the hardware resources required to create MPS client have been exhausted. + */ + CUDA_ERROR_MPS_MAX_CLIENTS_REACHED = 808, + + /** + * This error indicates the the hardware resources required to support device connections have been exhausted. + */ + CUDA_ERROR_MPS_MAX_CONNECTIONS_REACHED = 809, + /** * This error indicates that the operation is not permitted when * the stream is capturing. @@ -1635,6 +1967,34 @@ typedef enum cudaError_enum { */ CUDA_ERROR_CAPTURED_EVENT = 907, + /** + * A stream capture sequence not initiated with the ::CU_STREAM_CAPTURE_MODE_RELAXED + * argument to ::cuStreamBeginCapture was passed to ::cuStreamEndCapture in a + * different thread. + */ + CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD = 908, + + /** + * This error indicates that the timeout specified for the wait operation has lapsed. + */ + CUDA_ERROR_TIMEOUT = 909, + + /** + * This error indicates that the graph update was not performed because it included + * changes which violated constraints specific to instantiated graph update. + */ + CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE = 910, + + /** + * This indicates that an async error has occurred in a device outside of CUDA. + * If CUDA was waiting for an external device's signal before consuming shared data, + * the external device signaled an error indicating that the data is not valid for + * consumption. This leaves the process in an inconsistent state and any further CUDA + * work will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_EXTERNAL_DEVICE = 911, + /** * This indicates that an unknown internal error has occurred. */ @@ -1648,7 +2008,7 @@ typedef enum CUdevice_P2PAttribute_enum { CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK = 0x01, /**< A relative value indicating the performance of the link between two devices */ CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED = 0x02, /**< P2P Access is enable */ CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED = 0x03, /**< Atomic operation over the link supported */ - CU_DEVICE_P2P_ATTRIBUTE_ARRAY_ACCESS_ACCESS_SUPPORTED = 0x04, /**< \deprecated use CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED instead */ + CU_DEVICE_P2P_ATTRIBUTE_ACCESS_ACCESS_SUPPORTED = 0x04, /**< \deprecated use CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED instead */ CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED = 0x04 /**< Accessing CUDA arrays over the link supported */ } CUdevice_P2PAttribute; @@ -1708,15 +2068,25 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize); * On Windows the flag is a no-op. * On Linux that memory is marked as non cache-coherent for the GPU and * is expected to be physically contiguous. It may return - * CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, - * CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. - * On all other platforms, it is not supported and CUDA_ERROR_NOT_SUPPORTED + * ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, + * ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. + * On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED * is returned. * Flag for ::cuMemHostRegister() */ #define CU_MEMHOSTREGISTER_IOMEMORY 0x04 -#if __CUDA_API_VERSION >= 3020 +/** +* If set, the passed memory pointer is treated as pointing to memory that is +* considered read-only by the device. 
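/*
 * Minimal usage sketch (not part of the header diff): converting a CUresult
 * from the error enum above into a readable message with the existing driver
 * calls cuGetErrorName and cuGetErrorString.
 */
#include <cuda.h>
#include <stdio.h>

static void report(CUresult res) {
    if (res != CUDA_SUCCESS) {
        const char *name = NULL;
        const char *desc = NULL;
        cuGetErrorName(res, &name);
        cuGetErrorString(res, &desc);
        fprintf(stderr, "CUDA driver error %d: %s (%s)\n",
                (int)res, name ? name : "?", desc ? desc : "?");
    }
}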
On platforms without +* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is +* required in order to register memory mapped to the CPU as read-only. Support +* for the use of this flag can be queried from the device attribute +* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with +* a current context associated with a device that does not have this attribute +* set will cause ::cuMemHostRegister to error with ::CUDA_ERROR_NOT_SUPPORTED. +*/ +#define CU_MEMHOSTREGISTER_READ_ONLY 0x08 /** * 2D memory copy parameters @@ -1742,7 +2112,8 @@ typedef struct CUDA_MEMCPY2D_st { size_t WidthInBytes; /**< Width of 2D memory copy in bytes */ size_t Height; /**< Height of 2D memory copy */ -} CUDA_MEMCPY2D; +} CUDA_MEMCPY2D_v2; +typedef CUDA_MEMCPY2D_v2 CUDA_MEMCPY2D; /** * 3D memory copy parameters @@ -1775,7 +2146,8 @@ typedef struct CUDA_MEMCPY3D_st { size_t WidthInBytes; /**< Width of 3D memory copy in bytes */ size_t Height; /**< Height of 3D memory copy */ size_t Depth; /**< Depth of 3D memory copy */ -} CUDA_MEMCPY3D; +} CUDA_MEMCPY3D_v2; +typedef CUDA_MEMCPY3D_v2 CUDA_MEMCPY3D; /** * 3D memory cross-context copy parameters @@ -1808,7 +2180,8 @@ typedef struct CUDA_MEMCPY3D_PEER_st { size_t WidthInBytes; /**< Width of 3D memory copy in bytes */ size_t Height; /**< Height of 3D memory copy */ size_t Depth; /**< Depth of 3D memory copy */ -} CUDA_MEMCPY3D_PEER; +} CUDA_MEMCPY3D_PEER_v1; +typedef CUDA_MEMCPY3D_PEER_v1 CUDA_MEMCPY3D_PEER; /** * Array descriptor @@ -1820,7 +2193,8 @@ typedef struct CUDA_ARRAY_DESCRIPTOR_st CUarray_format Format; /**< Array format */ unsigned int NumChannels; /**< Channels per array element */ -} CUDA_ARRAY_DESCRIPTOR; +} CUDA_ARRAY_DESCRIPTOR_v2; +typedef CUDA_ARRAY_DESCRIPTOR_v2 CUDA_ARRAY_DESCRIPTOR; /** * 3D array descriptor @@ -1834,11 +2208,39 @@ typedef struct CUDA_ARRAY3D_DESCRIPTOR_st CUarray_format Format; /**< Array format */ unsigned int NumChannels; /**< Channels per array element */ unsigned int Flags; /**< Flags */ -} CUDA_ARRAY3D_DESCRIPTOR; +} CUDA_ARRAY3D_DESCRIPTOR_v2; +typedef CUDA_ARRAY3D_DESCRIPTOR_v2 CUDA_ARRAY3D_DESCRIPTOR; -#endif /* __CUDA_API_VERSION >= 3020 */ +/** + * Indicates that the layered sparse CUDA array or CUDA mipmapped array has a single mip tail region for all layers + */ +#define CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL 0x1 -#if __CUDA_API_VERSION >= 5000 +/** + * CUDA array sparse properties + */ +typedef struct CUDA_ARRAY_SPARSE_PROPERTIES_st { + struct { + unsigned int width; /**< Width of sparse tile in elements */ + unsigned int height; /**< Height of sparse tile in elements */ + unsigned int depth; /**< Depth of sparse tile in elements */ + } tileExtent; + + /** + * First mip level at which the mip tail begins. + */ + unsigned int miptailFirstLevel; + /** + * Total size of the mip tail. 
+ */ + unsigned long long miptailSize; + /** + * Flags will either be zero or ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL + */ + unsigned int flags; + unsigned int reserved[4]; +} CUDA_ARRAY_SPARSE_PROPERTIES_v1; +typedef CUDA_ARRAY_SPARSE_PROPERTIES_v1 CUDA_ARRAY_SPARSE_PROPERTIES; /** * CUDA Resource descriptor @@ -1874,7 +2276,8 @@ typedef struct CUDA_RESOURCE_DESC_st } res; unsigned int flags; /**< Flags (must be zero) */ -} CUDA_RESOURCE_DESC; +} CUDA_RESOURCE_DESC_v1; +typedef CUDA_RESOURCE_DESC_v1 CUDA_RESOURCE_DESC; /** * Texture descriptor @@ -1890,7 +2293,8 @@ typedef struct CUDA_TEXTURE_DESC_st { float maxMipmapLevelClamp; /**< Mipmap maximum level clamp */ float borderColor[4]; /**< Border Color */ int reserved[12]; -} CUDA_TEXTURE_DESC; +} CUDA_TEXTURE_DESC_v1; +typedef CUDA_TEXTURE_DESC_v1 CUDA_TEXTURE_DESC; /** * Resource view format @@ -1948,7 +2352,8 @@ typedef struct CUDA_RESOURCE_VIEW_DESC_st unsigned int firstLayer; /**< First layer index */ unsigned int lastLayer; /**< Last layer index */ unsigned int reserved[16]; -} CUDA_RESOURCE_VIEW_DESC; +} CUDA_RESOURCE_VIEW_DESC_v1; +typedef CUDA_RESOURCE_VIEW_DESC_v1 CUDA_RESOURCE_VIEW_DESC; /** * GPU Direct v3 tokens @@ -1956,11 +2361,18 @@ typedef struct CUDA_RESOURCE_VIEW_DESC_st typedef struct CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_st { unsigned long long p2pToken; unsigned int vaSpaceToken; -} CUDA_POINTER_ATTRIBUTE_P2P_TOKENS; +} CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_v1; +typedef CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_v1 CUDA_POINTER_ATTRIBUTE_P2P_TOKENS; -#endif /* __CUDA_API_VERSION >= 5000 */ - -#if __CUDA_API_VERSION >= 9000 +/** +* Access flags that specify the level of access the current context's device has +* on the memory referenced. +*/ +typedef enum CUDA_POINTER_ATTRIBUTE_ACCESS_FLAGS_enum { + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_NONE = 0x0, /**< No access, meaning the device cannot access this memory at all, thus must be staged through accessible memory in order to complete certain operations */ + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_READ = 0x1, /**< Read-only access, meaning writes to this memory are considered invalid accesses and thus return error in that case. 
*/ + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_READWRITE = 0x3 /**< Read-write access, the device has full read-write access to the memory */ +} CUDA_POINTER_ATTRIBUTE_ACCESS_FLAGS; /** * Kernel launch parameters @@ -1976,11 +2388,8 @@ typedef struct CUDA_LAUNCH_PARAMS_st { unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block in bytes */ CUstream hStream; /**< Stream identifier */ void **kernelParams; /**< Array of pointers to kernel parameters */ -} CUDA_LAUNCH_PARAMS; - -#endif /* __CUDA_API_VERSION >= 9000 */ - -#if __CUDA_API_VERSION >= 10000 +} CUDA_LAUNCH_PARAMS_v1; +typedef CUDA_LAUNCH_PARAMS_v1 CUDA_LAUNCH_PARAMS; /** * External memory handle types @@ -1989,23 +2398,35 @@ typedef enum CUexternalMemoryHandleType_enum { /** * Handle is an opaque file descriptor */ - CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, /** * Handle is an opaque shared NT handle */ - CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, /** * Handle is an opaque, globally shared handle */ - CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, /** * Handle is a D3D12 heap object */ - CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, /** * Handle is a D3D12 committed resource */ - CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5 + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5, + /** + * Handle is a shared NT handle to a D3D11 resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE = 6, + /** + * Handle is a globally shared handle to a D3D11 resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT = 7, + /** + * Handle is an NvSciBuf object + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF = 8 } CUexternalMemoryHandleType; /** @@ -2013,6 +2434,37 @@ typedef enum CUexternalMemoryHandleType_enum { */ #define CUDA_EXTERNAL_MEMORY_DEDICATED 0x1 +/** When the \p flags parameter of ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS + * contains this flag, it indicates that signaling an external semaphore object + * should skip performing appropriate memory synchronization operations over all + * the external memory objects that are imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, + * which otherwise are performed by default to ensure data coherency with other + * importers of the same NvSciBuf memory objects. + */ +#define CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC 0x01 + +/** When the \p flags parameter of ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS + * contains this flag, it indicates that waiting on an external semaphore object + * should skip performing appropriate memory synchronization operations over all + * the external memory objects that are imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, + * which otherwise are performed by default to ensure data coherency with other + * importers of the same NvSciBuf memory objects. + */ +#define CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC 0x02 + +/** + * When \p flags of ::cuDeviceGetNvSciSyncAttributes is set to this, + * it indicates that application needs signaler specific NvSciSyncAttr + * to be filled by ::cuDeviceGetNvSciSyncAttributes. + */ +#define CUDA_NVSCISYNC_ATTR_SIGNAL 0x1 + +/** + * When \p flags of ::cuDeviceGetNvSciSyncAttributes is set to this, + * it indicates that application needs waiter specific NvSciSyncAttr + * to be filled by ::cuDeviceGetNvSciSyncAttributes. 
+ */ +#define CUDA_NVSCISYNC_ATTR_WAIT 0x2 /** * External memory handle descriptor */ @@ -2035,9 +2487,12 @@ typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT * Exactly one of 'handle' and 'name' must be non-NULL. If - * type is - * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * type is one of the following: + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT * then 'name' must be NULL. */ struct { @@ -2051,6 +2506,11 @@ typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { */ const void *name; } win32; + /** + * A handle representing an NvSciBuf Object. Valid when type + * is ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF + */ + const void *nvSciBufObject; } handle; /** * Size of the memory allocation @@ -2061,7 +2521,8 @@ typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { */ unsigned int flags; unsigned int reserved[16]; -} CUDA_EXTERNAL_MEMORY_HANDLE_DESC; +} CUDA_EXTERNAL_MEMORY_HANDLE_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_HANDLE_DESC_v1 CUDA_EXTERNAL_MEMORY_HANDLE_DESC; /** * External memory buffer descriptor @@ -2080,7 +2541,8 @@ typedef struct CUDA_EXTERNAL_MEMORY_BUFFER_DESC_st { */ unsigned int flags; unsigned int reserved[16]; -} CUDA_EXTERNAL_MEMORY_BUFFER_DESC; +} CUDA_EXTERNAL_MEMORY_BUFFER_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_BUFFER_DESC_v1 CUDA_EXTERNAL_MEMORY_BUFFER_DESC; /** * External memory mipmap descriptor @@ -2100,7 +2562,8 @@ typedef struct CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_st { */ unsigned int numLevels; unsigned int reserved[16]; -} CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC; +} CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_v1 CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC; /** * External semaphore handle types @@ -2109,19 +2572,43 @@ typedef enum CUexternalSemaphoreHandleType_enum { /** * Handle is an opaque file descriptor */ - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, /** * Handle is an opaque shared NT handle */ - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, /** * Handle is an opaque, globally shared handle */ - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, /** * Handle is a shared NT handle referencing a D3D12 fence object */ - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4 + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4, + /** + * Handle is a shared NT handle referencing a D3D11 fence object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE = 5, + /** + * Opaque handle to NvSciSync Object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC = 6, + /** + * Handle is a shared NT handle referencing a D3D11 keyed mutex object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX = 7, + /** + * Handle is a globally shared handle referencing a D3D11 keyed mutex object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT = 8, + /** + * Handle is an opaque file descriptor referencing a timeline semaphore + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD = 9, + /** + * Handle is an opaque shared NT handle referencing a timeline semaphore + */ + 
CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 = 10 } CUexternalSemaphoreHandleType; /** @@ -2135,8 +2622,9 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { union { /** * File descriptor referencing the semaphore object. Valid - * when type is - * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD + * when type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD */ int fd; /** @@ -2145,9 +2633,13 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 * Exactly one of 'handle' and 'name' must be non-NULL. If - * type is - * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT * then 'name' must be NULL. */ struct { @@ -2161,13 +2653,18 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { */ const void *name; } win32; + /** + * Valid NvSciSyncObj. Must be non NULL + */ + const void* nvSciSyncObj; } handle; /** * Flags reserved for the future. Must be zero. */ unsigned int flags; unsigned int reserved[16]; -} CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; +} CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_v1 CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; /** * External semaphore signal parameters @@ -2183,14 +2680,39 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_st { */ unsigned long long value; } fence; - unsigned int reserved[16]; + union { + /** + * Pointer to NvSciSyncFence. Valid if ::CUexternalSemaphoreHandleType + * is of type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC. + */ + void *fence; + unsigned long long reserved; + } nvSciSync; + /** + * Parameters for keyed mutex objects + */ + struct { + /** + * Value of key to release the mutex with + */ + unsigned long long key; + } keyedMutex; + unsigned int reserved[12]; } params; /** - * Flags reserved for the future. Must be zero. + * Only when ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS is used to + * signal a ::CUexternalSemaphore of type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, the valid flag is + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC which indicates + * that while signaling the ::CUexternalSemaphore, no memory synchronization + * operations should be performed for any external memory object imported + * as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. + * For all other types of ::CUexternalSemaphore, flags must be zero. */ unsigned int flags; unsigned int reserved[16]; -} CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS; +} CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_v1 CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS; /** * External semaphore wait parameters @@ -2206,17 +2728,400 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st { */ unsigned long long value; } fence; - unsigned int reserved[16]; + /** + * Pointer to NvSciSyncFence. Valid if CUexternalSemaphoreHandleType + * is of type CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC. 
+ */ + union { + void *fence; + unsigned long long reserved; + } nvSciSync; + /** + * Parameters for keyed mutex objects + */ + struct { + /** + * Value of key to acquire the mutex with + */ + unsigned long long key; + /** + * Timeout in milliseconds to wait to acquire the mutex + */ + unsigned int timeoutMs; + } keyedMutex; + unsigned int reserved[10]; } params; /** - * Flags reserved for the future. Must be zero. + * Only when ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS is used to wait on + * a ::CUexternalSemaphore of type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, + * the valid flag is ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC + * which indicates that while waiting for the ::CUexternalSemaphore, no memory + * synchronization operations should be performed for any external memory + * object imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. + * For all other types of ::CUexternalSemaphore, flags must be zero. */ unsigned int flags; unsigned int reserved[16]; -} CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS; +} CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_v1 CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS; +/** + * Semaphore signal node parameters + */ +typedef struct CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_st { + CUexternalSemaphore* extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS* paramsArray; /**< Array of external semaphore signal parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in extSemArray and paramsArray. */ +} CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v1; +typedef CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v1 CUDA_EXT_SEM_SIGNAL_NODE_PARAMS; -#endif /* __CUDA_API_VERSION >= 10000 */ +/** + * Semaphore wait node parameters + */ +typedef struct CUDA_EXT_SEM_WAIT_NODE_PARAMS_st { + CUexternalSemaphore* extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS* paramsArray; /**< Array of external semaphore wait parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in extSemArray and paramsArray. */ +} CUDA_EXT_SEM_WAIT_NODE_PARAMS_v1; +typedef CUDA_EXT_SEM_WAIT_NODE_PARAMS_v1 CUDA_EXT_SEM_WAIT_NODE_PARAMS; + +typedef unsigned long long CUmemGenericAllocationHandle_v1; +typedef CUmemGenericAllocationHandle_v1 CUmemGenericAllocationHandle; + +/** + * Flags for specifying particular handle types + */ +typedef enum CUmemAllocationHandleType_enum { + CU_MEM_HANDLE_TYPE_NONE = 0x0, /**< Does not allow any export mechanism. > */ + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = 0x1, /**< Allows a file descriptor to be used for exporting. Permitted only on POSIX systems. (int) */ + CU_MEM_HANDLE_TYPE_WIN32 = 0x2, /**< Allows a Win32 NT handle to be used for exporting. (HANDLE) */ + CU_MEM_HANDLE_TYPE_WIN32_KMT = 0x4, /**< Allows a Win32 KMT handle to be used for exporting. (D3DKMT_HANDLE) */ + CU_MEM_HANDLE_TYPE_MAX = 0x7FFFFFFF +} CUmemAllocationHandleType; + +/** + * Specifies the memory protection flags for mapping. 
+ */ +typedef enum CUmemAccess_flags_enum { + CU_MEM_ACCESS_FLAGS_PROT_NONE = 0x0, /**< Default, make the address range not accessible */ + CU_MEM_ACCESS_FLAGS_PROT_READ = 0x1, /**< Make the address range read accessible */ + CU_MEM_ACCESS_FLAGS_PROT_READWRITE = 0x3, /**< Make the address range read-write accessible */ + CU_MEM_ACCESS_FLAGS_PROT_MAX = 0x7FFFFFFF +} CUmemAccess_flags; + +/** + * Specifies the type of location + */ +typedef enum CUmemLocationType_enum { + CU_MEM_LOCATION_TYPE_INVALID = 0x0, + CU_MEM_LOCATION_TYPE_DEVICE = 0x1, /**< Location is a device location, thus id is a device ordinal */ + CU_MEM_LOCATION_TYPE_MAX = 0x7FFFFFFF +} CUmemLocationType; + +/** +* Defines the allocation types available +*/ +typedef enum CUmemAllocationType_enum { + CU_MEM_ALLOCATION_TYPE_INVALID = 0x0, + + /** This allocation type is 'pinned', i.e. cannot migrate from its current + * location while the application is actively using it + */ + CU_MEM_ALLOCATION_TYPE_PINNED = 0x1, + CU_MEM_ALLOCATION_TYPE_MAX = 0x7FFFFFFF +} CUmemAllocationType; + +/** +* Flag for requesting different optimal and required granularities for an allocation. +*/ +typedef enum CUmemAllocationGranularity_flags_enum { + CU_MEM_ALLOC_GRANULARITY_MINIMUM = 0x0, /**< Minimum required granularity for allocation */ + CU_MEM_ALLOC_GRANULARITY_RECOMMENDED = 0x1 /**< Recommended granularity for allocation for best performance */ +} CUmemAllocationGranularity_flags; + +/** + * Sparse subresource types + */ +typedef enum CUarraySparseSubresourceType_enum { + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL = 0, + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL = 1 +} CUarraySparseSubresourceType; + +/** + * Memory operation types + */ +typedef enum CUmemOperationType_enum { + CU_MEM_OPERATION_TYPE_MAP = 1, + CU_MEM_OPERATION_TYPE_UNMAP = 2 +} CUmemOperationType; + +/** + * Memory handle types + */ +typedef enum CUmemHandleType_enum { + CU_MEM_HANDLE_TYPE_GENERIC = 0 +} CUmemHandleType; + +/** + * Specifies the CUDA array or CUDA mipmapped array memory mapping information + */ +typedef struct CUarrayMapInfo_st { + CUresourcetype resourceType; /**< Resource type */ + + union { + CUmipmappedArray mipmap; + CUarray array; + } resource; + + CUarraySparseSubresourceType subresourceType; /**< Sparse subresource type */ + + union { + struct { + unsigned int level; /**< For CUDA mipmapped arrays must a valid mipmap level. For CUDA arrays must be zero */ + unsigned int layer; /**< For CUDA layered arrays must be a valid layer index. Otherwise, must be zero */ + unsigned int offsetX; /**< Starting X offset in elements */ + unsigned int offsetY; /**< Starting Y offset in elements */ + unsigned int offsetZ; /**< Starting Z offset in elements */ + unsigned int extentWidth; /**< Width in elements */ + unsigned int extentHeight; /**< Height in elements */ + unsigned int extentDepth; /**< Depth in elements */ + } sparseLevel; + struct { + unsigned int layer; /**< For CUDA layered arrays must be a valid layer index. 
Otherwise, must be zero */ + unsigned long long offset; /**< Offset within mip tail */ + unsigned long long size; /**< Extent in bytes */ + } miptail; + } subresource; + + CUmemOperationType memOperationType; /**< Memory operation type */ + CUmemHandleType memHandleType; /**< Memory handle type */ + + union { + CUmemGenericAllocationHandle memHandle; + } memHandle; + + unsigned long long offset; /**< Offset within the memory */ + unsigned int deviceBitMask; /**< Device ordinal bit mask */ + unsigned int flags; /**< flags for future use, must be zero now. */ + unsigned int reserved[2]; /**< Reserved for future use, must be zero now. */ +} CUarrayMapInfo_v1; +typedef CUarrayMapInfo_v1 CUarrayMapInfo; + +/** + * Specifies a memory location. + */ +typedef struct CUmemLocation_st { + CUmemLocationType type; /**< Specifies the location type, which modifies the meaning of id. */ + int id; /**< identifier for a given this location's ::CUmemLocationType. */ +} CUmemLocation_v1; +typedef CUmemLocation_v1 CUmemLocation; + +/** + * Specifies compression attribute for an allocation. + */ +typedef enum CUmemAllocationCompType_enum { + CU_MEM_ALLOCATION_COMP_NONE = 0x0, /**< Allocating non-compressible memory */ + CU_MEM_ALLOCATION_COMP_GENERIC = 0x1 /**< Allocating compressible memory */ +} CUmemAllocationCompType; + +/** + * This flag if set indicates that the memory will be used as a tile pool. + */ +#define CU_MEM_CREATE_USAGE_TILE_POOL 0x1 + +/** +* Specifies the allocation properties for a allocation. +*/ +typedef struct CUmemAllocationProp_st { + /** Allocation type */ + CUmemAllocationType type; + /** requested ::CUmemAllocationHandleType */ + CUmemAllocationHandleType requestedHandleTypes; + /** Location of allocation */ + CUmemLocation location; + /** + * Windows-specific POBJECT_ATTRIBUTES required when + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object atributes structure + * includes security attributes that define + * the scope of which exported allocations may be tranferred to other + * processes. In all other cases, this field is required to be zero. + */ + void *win32HandleMetaData; + struct { + /** + * Allocation hint for requesting compressible memory. + * On devices that support Compute Data Compression, compressible + * memory can be used to accelerate accesses to data with unstructured + * sparsity and other compressible data patterns. Applications are + * expected to query allocation property of the handle obtained with + * ::cuMemCreate using ::cuMemGetAllocationPropertiesFromHandle to + * validate if the obtained allocation is compressible or not. Note that + * compressed memory may not be mappable on all devices. 
+ */ + unsigned char compressionType; + unsigned char gpuDirectRDMACapable; + /** Bitmask indicating intended usage for this allocation */ + unsigned short usage; + unsigned char reserved[4]; + } allocFlags; +} CUmemAllocationProp_v1; +typedef CUmemAllocationProp_v1 CUmemAllocationProp; + +/** + * Memory access descriptor + */ +typedef struct CUmemAccessDesc_st { + CUmemLocation location; /**< Location on which the request is to change it's accessibility */ + CUmemAccess_flags flags; /**< ::CUmemProt accessibility flags to set on the request */ +} CUmemAccessDesc_v1; +typedef CUmemAccessDesc_v1 CUmemAccessDesc; + +typedef enum CUgraphExecUpdateResult_enum { + CU_GRAPH_EXEC_UPDATE_SUCCESS = 0x0, /**< The update succeeded */ + CU_GRAPH_EXEC_UPDATE_ERROR = 0x1, /**< The update failed for an unexpected reason which is described in the return value of the function */ + CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED = 0x2, /**< The update failed because the topology changed */ + CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED = 0x3, /**< The update failed because a node type changed */ + CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED = 0x4, /**< The update failed because the function of a kernel node changed (CUDA driver < 11.2) */ + CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED = 0x5, /**< The update failed because the parameters changed in a way that is not supported */ + CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED = 0x6, /**< The update failed because something about the node is not supported */ + CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE = 0x7 /**< The update failed because the function of a kernel node changed in an unsupported way */ +} CUgraphExecUpdateResult; + +/** + * CUDA memory pool attributes + */ +typedef enum CUmemPool_attribute_enum { + /** + * (value type = int) + * Allow cuMemAllocAsync to use memory asynchronously freed + * in another streams as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + */ + CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES = 1, + + /** + * (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + */ + CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, + + /** + * (value type = int) + * Allow cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by cuFreeAsync (default enabled). + */ + CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES, + + /** + * (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + */ + CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, + + /** + * (value type = cuuint64_t) + * Amount of backing memory currently allocated for the mempool. + */ + CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of backing memory allocated for the mempool since the + * last time it was reset. High watermark can only be reset to zero. + */ + CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, + + /** + * (value type = cuuint64_t) + * Amount of memory from the pool that is currently in use by the application. 
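
/*
 * Usage sketch for ::CUmemAllocationProp / ::CUmemAccessDesc above: a minimal
 * physical allocation with ::cuMemCreate mapped into a reserved VA range and
 * made read-write. Illustrative only; assumes a current context on `dev`, and
 * omits error checking and the matching cuMemUnmap/cuMemAddressFree cleanup.
 */
static CUdeviceptr example_vmm_alloc(CUdevice dev, size_t bytes)
{
    CUmemAllocationProp prop = {0};
    prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id   = (int)dev;

    size_t gran = 0;
    cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    size_t padded = ((bytes + gran - 1) / gran) * gran;

    CUmemGenericAllocationHandle handle;
    cuMemCreate(&handle, padded, &prop, 0);

    CUdeviceptr va = 0;
    cuMemAddressReserve(&va, padded, 0, 0, 0);
    cuMemMap(va, padded, 0, handle, 0);

    CUmemAccessDesc access = {0};
    access.location = prop.location;
    access.flags    = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(va, padded, &access, 1);

    cuMemRelease(handle);   /* the mapping keeps the physical allocation alive */
    return va;
}
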
+ */ + CU_MEMPOOL_ATTR_USED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of the amount of memory from the pool that was in use by the application since + * the last time it was reset. High watermark can only be reset to zero. + */ + CU_MEMPOOL_ATTR_USED_MEM_HIGH +} CUmemPool_attribute; + +/** + * Specifies the properties of allocations made from the pool. + */ +typedef struct CUmemPoolProps_st { + CUmemAllocationType allocType; /**< Allocation type. Currently must be specified as CU_MEM_ALLOCATION_TYPE_PINNED */ + CUmemAllocationHandleType handleTypes; /**< Handle types that will be supported by allocations from the pool. */ + CUmemLocation location; /**< Location where allocations should reside. */ + /** + * Windows-specific LPSECURITYATTRIBUTES required when + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines + * the scope of which exported allocations may be tranferred to other + * processes. In all other cases, this field is required to be zero. + */ + void *win32SecurityAttributes; + unsigned char reserved[64]; /**< reserved for future use, must be 0 */ +} CUmemPoolProps_v1; +typedef CUmemPoolProps_v1 CUmemPoolProps; + +/** + * Opaque data for exporting a pool allocation + */ +typedef struct CUmemPoolPtrExportData_st { + unsigned char reserved[64]; +} CUmemPoolPtrExportData_v1; +typedef CUmemPoolPtrExportData_v1 CUmemPoolPtrExportData; + +/** + * Memory allocation node parameters + */ +typedef struct CUDA_MEM_ALLOC_NODE_PARAMS_st { + /** + * in: location where the allocation should reside (specified in ::location). + * ::handleTypes must be ::CU_MEM_HANDLE_TYPE_NONE. IPC is not supported. + */ + CUmemPoolProps poolProps; + const CUmemAccessDesc *accessDescs; /**< in: array of memory access descriptors. Used to describe peer GPU access */ + size_t accessDescCount; /**< in: number of memory access descriptors. Must not exceed the number of GPUs. */ + size_t bytesize; /**< in: size in bytes of the requested allocation */ + CUdeviceptr dptr; /**< out: address of the allocation returned by CUDA */ +} CUDA_MEM_ALLOC_NODE_PARAMS; + +typedef enum CUgraphMem_attribute_enum { + /** + * (value type = cuuint64_t) + * Amount of memory, in bytes, currently associated with graphs + */ + CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. + */ + CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, + + /** + * (value type = cuuint64_t) + * Amount of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + */ + CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + */ + CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH +} CUgraphMem_attribute; /** * If set, each kernel launched as part of ::cuLaunchCooperativeKernelMultiDevice only @@ -2276,6 +3181,12 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st { */ #define CUDA_ARRAY3D_COLOR_ATTACHMENT 0x20 +/** + * This flag if set indicates that the CUDA array or CUDA mipmapped array + * is a sparse CUDA array or CUDA mipmapped array respectively + */ +#define CUDA_ARRAY3D_SPARSE 0x40 + /** * Override the texref format with a format inferred from the array. 
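
/*
 * Usage sketch for ::CUmemPoolProps above: creating an explicit pool on `dev`
 * and making a stream-ordered allocation from it. Illustrative only; assumes
 * a valid `stream` and omits error checking.
 */
static void example_explicit_pool(CUdevice dev, CUstream stream)
{
    CUmemPoolProps props = {0};
    props.allocType     = CU_MEM_ALLOCATION_TYPE_PINNED;
    props.handleTypes   = CU_MEM_HANDLE_TYPE_NONE;     /* no IPC export */
    props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    props.location.id   = (int)dev;

    CUmemoryPool pool;
    cuMemPoolCreate(&pool, &props);

    CUdeviceptr p;
    cuMemAllocFromPoolAsync(&p, 1 << 20, pool, stream);
    cuMemFreeAsync(p, stream);

    cuStreamSynchronize(stream);
    cuMemPoolDestroy(pool);
}
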
* Flag for ::cuTexRefSetArray() @@ -2285,22 +3196,28 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st { /** * Read the texture as integers rather than promoting the values to floats * in the range [0,1]. - * Flag for ::cuTexRefSetFlags() + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() */ #define CU_TRSF_READ_AS_INTEGER 0x01 /** * Use normalized texture coordinates in the range [0,1) instead of [0,dim). - * Flag for ::cuTexRefSetFlags() + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() */ #define CU_TRSF_NORMALIZED_COORDINATES 0x02 /** * Perform sRGB->linear conversion during texture read. - * Flag for ::cuTexRefSetFlags() + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() */ #define CU_TRSF_SRGB 0x10 + /** + * Disable any trilinear filtering optimizations. + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION 0x20 + /** * End of array terminator for the \p extra parameter to * ::cuLaunchKernel @@ -2344,8 +3261,86 @@ typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st { */ #define CU_DEVICE_INVALID ((CUdevice)-2) +/** + * Bitmasks for ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS + */ +typedef enum CUflushGPUDirectRDMAWritesOptions_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_HOST = 1<<0, /**< ::cuFlushGPUDirectRDMAWrites() and its CUDA Runtime API counterpart are supported on the device. */ + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_MEMOPS = 1<<1 /**< The ::CU_STREAM_WAIT_VALUE_FLUSH flag and the ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on the device. */ +} CUflushGPUDirectRDMAWritesOptions; + +/** + * Platform native ordering for GPUDirect RDMA writes + */ +typedef enum CUGPUDirectRDMAWritesOrdering_enum { + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_NONE = 0, /**< The device does not natively support ordering of remote writes. ::cuFlushGPUDirectRDMAWrites() can be leveraged if supported. */ + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_OWNER = 100, /**< Natively, the device can consistently consume remote writes, although other CUDA devices may not. */ + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_ALL_DEVICES = 200 /**< Any CUDA device in the system can consistently consume remote writes to this device. */ +} CUGPUDirectRDMAWritesOrdering; + +/** + * The scopes for ::cuFlushGPUDirectRDMAWrites + */ +typedef enum CUflushGPUDirectRDMAWritesScope_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER = 100, /**< Blocks until remote writes are visible to the CUDA device context owning the data. */ + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_ALL_DEVICES = 200 /**< Blocks until remote writes are visible to all CUDA device contexts. */ +} CUflushGPUDirectRDMAWritesScope; + +/** + * The targets for ::cuFlushGPUDirectRDMAWrites + */ +typedef enum CUflushGPUDirectRDMAWritesTarget_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX = 0 /**< Sets the target for ::cuFlushGPUDirectRDMAWrites() to the currently active CUDA device context. 
*/ +} CUflushGPUDirectRDMAWritesTarget; + +/** + * The additional write options for ::cuGraphDebugDotPrint + */ +typedef enum CUgraphDebugDot_flags_enum { + CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE = 1<<0, /** Output all debug data as if every debug flag is enabled */ + CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES = 1<<1, /** Use CUDA Runtime structures for output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS = 1<<2, /** Adds CUDA_KERNEL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS = 1<<3, /** Adds CUDA_MEMCPY3D values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS = 1<<4, /** Adds CUDA_MEMSET_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS = 1<<5, /** Adds CUDA_HOST_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS = 1<<6, /** Adds CUevent handle from record and wait nodes to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS = 1<<7, /** Adds CUDA_EXT_SEM_SIGNAL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS = 1<<8, /** Adds CUDA_EXT_SEM_WAIT_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES = 1<<9, /** Adds CUkernelNodeAttrValue values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES = 1<<10, /** Adds node handles and every kernel function handle to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS = 1<<11, /** Adds memory alloc node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS = 1<<12 /** Adds memory free node parameters to output */ +} CUgraphDebugDot_flags; + +/** + * Flags for user objects for graphs + */ +typedef enum CUuserObject_flags_enum { + CU_USER_OBJECT_NO_DESTRUCTOR_SYNC = 1 /**< Indicates the destructor execution is not synchronized by any CUDA handle. */ +} CUuserObject_flags; + +/** + * Flags for retaining user object references for graphs + */ +typedef enum CUuserObjectRetain_flags_enum { + CU_GRAPH_USER_OBJECT_MOVE = 1 /**< Transfer references from the caller rather than creating new references. */ +} CUuserObjectRetain_flags; + +/** + * Flags for instantiating a graph + */ +typedef enum CUgraphInstantiate_flags_enum { + CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH = 1 /**< Automatically free memory allocated in a graph before relaunching. 
*/ +} CUgraphInstantiate_flags; + /** @} */ /* END CUDA_TYPES */ +#if defined(__GNUC__) + #if defined(__CUDA_API_PUSH_VISIBILITY_DEFAULT) + #pragma GCC visibility push(default) + #endif +#endif + #ifdef _WIN32 #define CUDAAPI __stdcall #else @@ -2433,7 +3428,9 @@ CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr); * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_INVALID_DEVICE + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_SYSTEM_DRIVER_MISMATCH, + * ::CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE * \notefnerr */ CUresult CUDAAPI cuInit(unsigned int Flags); @@ -2514,7 +3511,8 @@ CUresult CUDAAPI cuDriverGetVersion(int *driverVersion); * ::cuDeviceGetName, * ::cuDeviceGetUuid, * ::cuDeviceGetLuid, - * ::cuDeviceTotalMem + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport */ CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal); @@ -2542,6 +3540,7 @@ CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal); * ::cuDeviceGetLuid, * ::cuDeviceGet, * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, * ::cudaGetDeviceCount */ CUresult CUDAAPI cuDeviceGetCount(int *count); @@ -2573,14 +3572,17 @@ CUresult CUDAAPI cuDeviceGetCount(int *count); * ::cuDeviceGetCount, * ::cuDeviceGet, * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, * ::cudaGetDeviceProperties */ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev); -#if __CUDA_API_VERSION >= 9020 /** * \brief Return an UUID for the device * + * Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will + * supplant this version in 12.0, which is retained for minor version compatibility. + * * Returns 16-octets identifing the device \p dev in the structure * pointed by the \p uuid. * @@ -2596,6 +3598,37 @@ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev); * \notefnerr * * \sa + * ::cuDeviceGetUuid_v2 + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev); + +/** + * \brief Return an UUID for the device (11.4+) + * + * Returns 16-octets identifing the device \p dev in the structure + * pointed by the \p uuid. If the device is in MIG mode, returns its + * MIG UUID which uniquely identifies the subscribed MIG compute instance. 
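
/*
 * Usage sketch for the ::CUgraphDebugDot_flags, ::CUgraphInstantiate_flags and
 * ::CUgraphMem_attribute values above: dumping a graph to a DOT file,
 * instantiating it with automatic freeing of graph-owned allocations, and
 * querying graph memory usage via ::cuDeviceGetGraphMemAttribute. Illustrative
 * only; `graph` is assumed to be an already constructed ::CUgraph and error
 * checking is omitted.
 */
static CUgraphExec example_instantiate_graph(CUdevice dev, CUgraph graph)
{
    cuGraphDebugDotPrint(graph, "graph.dot",
                         CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS |
                         CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS);

    CUgraphExec exec;
    cuGraphInstantiateWithFlags(&exec, graph,
                                CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH);

    /* How much memory graph allocations currently hold on this device. */
    cuuint64_t used = 0;
    cuDeviceGetGraphMemAttribute(dev, CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, &used);
    return exec;
}
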
+ * + * \param uuid - Returned UUID + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa * ::cuDeviceGetAttribute, * ::cuDeviceGetCount, * ::cuDeviceGetName, @@ -2604,10 +3637,8 @@ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev); * ::cuDeviceTotalMem, * ::cudaGetDeviceProperties */ -CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev); -#endif +CUresult CUDAAPI cuDeviceGetUuid_v2(CUuuid *uuid, CUdevice dev); -#if defined(_WIN32) && __CUDA_API_VERSION >= 10000 /** * \brief Return an LUID and device node mask for the device * @@ -2632,12 +3663,11 @@ CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev); * ::cuDeviceGetName, * ::cuDeviceGet, * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, * ::cudaGetDeviceProperties */ CUresult CUDAAPI cuDeviceGetLuid(char *luid, unsigned int *deviceNodeMask, CUdevice dev); -#endif -#if __CUDA_API_VERSION >= 3020 /** * \brief Returns the total amount of memory on the device * @@ -2662,10 +3692,41 @@ CUresult CUDAAPI cuDeviceGetLuid(char *luid, unsigned int *deviceNodeMask, CUdev * ::cuDeviceGetName, * ::cuDeviceGetUuid, * ::cuDeviceGet, + * ::cuDeviceGetExecAffinitySupport, * ::cudaMemGetInfo */ CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); -#endif /* __CUDA_API_VERSION >= 3020 */ + +/** + * \brief Returns the maximum number of elements allocatable in a 1D linear texture for a given texture element size. + * + * Returns in \p maxWidthInElements the maximum number of texture elements allocatable in a 1D linear texture + * for given \p format and \p numChannels. + * + * \param maxWidthInElements - Returned maximum number of texture elements allocatable for given \p format and \p numChannels. + * \param format - Texture format. + * \param numChannels - Number of channels per texture element. + * \param dev - Device handle. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cudaMemGetInfo, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, CUarray_format format, unsigned numChannels, CUdevice dev); /** * \brief Returns information about the device @@ -2674,117 +3735,117 @@ CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); * \p dev. 
The supported attributes are: * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per * block; - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block; - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block; - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block; - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid; - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid; - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of - * shared memory available to a thread block in bytes; + * shared memory available to a thread block in bytes * - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for - * __constant__ variables in a CUDA C kernel in bytes; - * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads; + * __constant__ variables in a CUDA C kernel in bytes + * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads * - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the * memory copy functions that involve memory regions allocated through - * ::cuMemAllocPitch(); + * ::cuMemAllocPitch() * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D - * texture width; + * texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width - * for a 1D texture bound to linear memory; + * for a 1D texture bound to linear memory * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum - * mipmapped 1D texture width; + * mipmapped 1D texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D - * texture width; + * texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D - * texture height; + * texture height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width - * for a 2D texture bound to linear memory; + * for a 2D texture bound to linear memory * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height - * for a 2D texture bound to linear memory; + * for a 2D texture bound to linear memory * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch - * in bytes for a 2D texture bound to linear memory; + * in bytes for a 2D texture bound to linear memory * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum - * mipmapped 2D texture width; + * mipmapped 2D texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum - * mipmapped 2D texture height; + * mipmapped 2D texture height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D - * texture width; + * texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D - * texture height; + * texture height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D - * texture depth; + * texture depth * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE: * Alternate maximum 3D texture width, 0 if no alternate - * maximum 3D texture size is supported; + * maximum 3D 
texture size is supported * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE: * Alternate maximum 3D texture height, 0 if no alternate - * maximum 3D texture size is supported; + * maximum 3D texture size is supported * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE: * Alternate maximum 3D texture depth, 0 if no alternate - * maximum 3D texture size is supported; + * maximum 3D texture size is supported * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH: - * Maximum cubemap texture width or height; + * Maximum cubemap texture width or height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH: - * Maximum 1D layered texture width; + * Maximum 1D layered texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS: - * Maximum layers in a 1D layered texture; + * Maximum layers in a 1D layered texture * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH: - * Maximum 2D layered texture width; + * Maximum 2D layered texture width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT: - * Maximum 2D layered texture height; + * Maximum 2D layered texture height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS: - * Maximum layers in a 2D layered texture; + * Maximum layers in a 2D layered texture * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH: - * Maximum cubemap layered texture width or height; + * Maximum cubemap layered texture width or height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS: - * Maximum layers in a cubemap layered texture; + * Maximum layers in a cubemap layered texture * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH: - * Maximum 1D surface width; + * Maximum 1D surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH: - * Maximum 2D surface width; + * Maximum 2D surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT: - * Maximum 2D surface height; + * Maximum 2D surface height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH: - * Maximum 3D surface width; + * Maximum 3D surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT: - * Maximum 3D surface height; + * Maximum 3D surface height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH: - * Maximum 3D surface depth; + * Maximum 3D surface depth * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH: - * Maximum 1D layered surface width; + * Maximum 1D layered surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS: - * Maximum layers in a 1D layered surface; + * Maximum layers in a 1D layered surface * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH: - * Maximum 2D layered surface width; + * Maximum 2D layered surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT: - * Maximum 2D layered surface height; + * Maximum 2D layered surface height * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS: - * Maximum layers in a 2D layered surface; + * Maximum layers in a 2D layered surface * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH: - * Maximum cubemap surface width; + * Maximum cubemap surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH: - * Maximum cubemap layered surface width; + * Maximum cubemap layered surface width * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS: - * Maximum layers in a cubemap layered surface; + * Maximum layers in a cubemap layered surface * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit - * registers available to a thread block; - * - 
::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz; + * registers available to a thread block + * - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture * base addresses aligned to ::textureAlign bytes do not need an offset - * applied to texture fetches; + * applied to texture fetches * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement - * for 2D texture references bound to pitched memory; + * for 2D texture references bound to pitched memory * - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy - * memory between host and device while executing a kernel, or 0 if not; + * memory between host and device while executing a kernel, or 0 if not * - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on - * the device; + * the device * - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit - * for kernels executed on the device, or 0 if not; + * for kernels executed on the device, or 0 if not * - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the - * memory subsystem, or 0 if not; + * memory subsystem, or 0 if not * - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host - * memory into the CUDA address space, or 0 if not; + * memory into the CUDA address space, or 0 if not * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently * in. Available modes are as follows: * - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and @@ -2797,33 +3858,33 @@ CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); * executing multiple kernels within the same context simultaneously, or 0 if * not. It is not guaranteed that multiple kernels will be resident * on the device concurrently so this feature should not be relied upon for - * correctness; + * correctness. * - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the - * device, 0 if error correction is disabled or not supported by the device; - * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device; + * device, 0 if error correction is disabled or not supported by the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device * - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier - * of the device; + * of the device * - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device * - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. TCC - * is only available on Tesla hardware running Windows Vista or later; - * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz; - * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits; - * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache; - * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor; + * is only available on Tesla hardware running Windows Vista or later + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits + * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 
0 if the device doesn't have L2 cache + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor * - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with - * the host, or 0 if not; - * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number; - * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number; + * the host, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals - * in L1 cache, 0 if caching globals in L1 cache is not supported by the device; + * in L1 cache, 0 if caching globals in L1 cache is not supported by the device * - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals - * in L1 cache, 0 if caching locals in L1 cache is not supported by the device; + * in L1 cache, 0 if caching locals in L1 cache is not supported by the device * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of * shared memory available to a multiprocessor in bytes; this amount is shared - * by all thread blocks simultaneously resident on a multiprocessor; + * by all thread blocks simultaneously resident on a multiprocessor * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit * registers available to a multiprocessor; this number is shared by all thread - * blocks simultaneously resident on a multiprocessor; + * blocks simultaneously resident on a multiprocessor * - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory * on this system, 0 if allocating managed memory is not supported by the device on this system. * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not. @@ -2846,6 +3907,23 @@ CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device accesses pageable memory via the host's * page tables. * - ::CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST: The host can directly access managed memory on the device without migration. 
+ * - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes + * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate. + * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes + * - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays. + * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. + * - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC * * \param pi - Returned device attribute value * \param attrib - Device attribute to query @@ -2866,11 +3944,144 @@ CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); * ::cuDeviceGetUuid, * ::cuDeviceGet, * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, * ::cudaDeviceGetAttribute, * ::cudaGetDeviceProperties */ CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); +/** + * \brief Return NvSciSync attributes that this device can support. + * + * Returns in \p nvSciSyncAttrList, the properties of NvSciSync that + * this CUDA device, \p dev can support. The returned \p nvSciSyncAttrList + * can be used to create an NvSciSync object that matches this device's capabilities. 
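
/*
 * Usage sketch for ::cuDeviceGetAttribute with the attributes added above:
 * checking stream-ordered memory pool support and decoding the GPUDirect RDMA
 * flush-writes options bitmask. Illustrative only; error checking omitted.
 */
static void example_query_new_attributes(CUdevice dev)
{
    int pools_supported = 0;
    int rdma_options    = 0;

    cuDeviceGetAttribute(&pools_supported,
                         CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev);
    cuDeviceGetAttribute(&rdma_options,
                         CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS, dev);

    if (rdma_options & CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_HOST) {
        /* cuFlushGPUDirectRDMAWrites() may be used on this device. */
    }
}
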
+ * + * If NvSciSyncAttrKey_RequiredPerm field in \p nvSciSyncAttrList is + * already set this API will return ::CUDA_ERROR_INVALID_VALUE. + * + * The applications should set \p nvSciSyncAttrList to a valid + * NvSciSyncAttrList failing which this API will return + * ::CUDA_ERROR_INVALID_HANDLE. + * + * The \p flags controls how applications intends to use + * the NvSciSync created from the \p nvSciSyncAttrList. The valid flags are: + * - ::CUDA_NVSCISYNC_ATTR_SIGNAL, specifies that the applications intends to + * signal an NvSciSync on this CUDA device. + * - ::CUDA_NVSCISYNC_ATTR_WAIT, specifies that the applications intends to + * wait on an NvSciSync on this CUDA device. + * + * At least one of these flags must be set, failing which the API + * returns ::CUDA_ERROR_INVALID_VALUE. Both the flags are orthogonal + * to one another: a developer may set both these flags that allows to + * set both wait and signal specific attributes in the same \p nvSciSyncAttrList. + * + * \param nvSciSyncAttrList - Return NvSciSync attributes supported. + * \param dev - Valid Cuda Device to get NvSciSync attributes for. + * \param flags - flags describing NvSciSync usage. + * + * \return + * + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa + * ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuDeviceGetNvSciSyncAttributes(void *nvSciSyncAttrList, CUdevice dev, int flags); + +/** + * \brief Sets the current memory pool of a device + * + * The memory pool must be local to the specified device. + * ::cuMemAllocAsync allocates from the current mempool of the provided stream's device. + * By default, a device's current memory pool is its default memory pool. + * + * \note Use ::cuMemAllocFromPoolAsync to specify asynchronous allocations from a device different + * than the one the stream runs on. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, ::cuMemPoolDestroy, ::cuMemAllocFromPoolAsync + */ +CUresult CUDAAPI cuDeviceSetMemPool(CUdevice dev, CUmemoryPool pool); + +/** + * \brief Gets the current mempool for a device + * + * Returns the last pool provided to ::cuDeviceSetMemPool for this device + * or the device's default memory pool if ::cuDeviceSetMemPool has never been called. + * By default the current mempool is the default mempool for a device. + * Otherwise the returned pool must have been set with ::cuDeviceSetMemPool. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuDeviceGetDefaultMemPool, ::cuMemPoolCreate, ::cuDeviceSetMemPool + */ +CUresult CUDAAPI cuDeviceGetMemPool(CUmemoryPool *pool, CUdevice dev); + +/** + * \brief Returns the default mempool of a device + * + * The default mempool of a device contains device memory from that device. 
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuMemAllocAsync, ::cuMemPoolTrimTo, ::cuMemPoolGetAttribute, ::cuMemPoolSetAttribute, cuMemPoolSetAccess, ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuDeviceGetDefaultMemPool(CUmemoryPool *pool_out, CUdevice dev); + +/** + * \brief Blocks until remote writes are visible to the specified scope + * + * Blocks until GPUDirect RDMA writes to the target context via mappings + * created through APIs like nvidia_p2p_get_pages (see + * https://docs.nvidia.com/cuda/gpudirect-rdma for more information), are + * visible to the specified scope. + * + * If the scope equals or lies within the scope indicated by + * ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING, the call + * will be a no-op and can be safely omitted for performance. This can be + * determined by comparing the numerical values between the two enums, with + * smaller scopes having smaller values. + * + * Users may query support for this API via + * ::CU_DEVICE_ATTRIBUTE_FLUSH_FLUSH_GPU_DIRECT_RDMA_OPTIONS. + * + * \param target - The target of the operation, see ::CUflushGPUDirectRDMAWritesTarget + * \param scope - The scope of the operation, see ::CUflushGPUDirectRDMAWritesScope + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + */ +CUresult CUDAAPI cuFlushGPUDirectRDMAWrites(CUflushGPUDirectRDMAWritesTarget target, CUflushGPUDirectRDMAWritesScope scope); + /** @} */ /* END CUDA_DEVICE */ /** @@ -3000,20 +4211,19 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuDeviceComputeCapability(int *major, int *mi * @{ */ -#if __CUDA_API_VERSION >= 7000 - /** * \brief Retain the primary context on the GPU * - * Retains the primary context on the device, creating it if necessary, - * increasing its usage count. The caller must call - * ::cuDevicePrimaryCtxRelease() when done using the context. - * Unlike ::cuCtxCreate() the newly created context is not pushed onto the stack. + * Retains the primary context on the device. + * Once the user successfully retains the primary context, the primary context + * will be active and available to the user until the user releases it + * with ::cuDevicePrimaryCtxRelease() or resets it with ::cuDevicePrimaryCtxReset(). + * Unlike ::cuCtxCreate() the newly retained context is not pushed onto the stack. * - * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of - * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute() - * can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to determine the compute mode - * of the device. + * Retaining the primary context for the first time will fail with ::CUDA_ERROR_UNKNOWN + * if the compute mode of the device is ::CU_COMPUTEMODE_PROHIBITED. The function + * ::cuDeviceGetAttribute() can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to + * determine the compute mode of the device. * The nvidia-smi tool can be used to set the compute mode for * devices. Documentation for nvidia-smi can be obtained by passing a * -h option to it. 
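
/*
 * Usage sketch for the memory-pool entry points above
 * (::cuDeviceGetDefaultMemPool plus the ::CUmemPool_attribute values): raising
 * the release threshold of a device's default pool so stream-ordered
 * allocations stay cached instead of being returned to the OS on every
 * synchronize. Illustrative only; assumes a valid `stream`, error checking
 * omitted.
 */
static void example_default_pool(CUdevice dev, CUstream stream)
{
    CUmemoryPool pool;
    cuDeviceGetDefaultMemPool(&pool, dev);

    cuuint64_t threshold = 64ull << 20;    /* keep up to 64 MiB cached in the pool */
    cuMemPoolSetAttribute(pool, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &threshold);

    CUdeviceptr p;
    cuMemAllocAsync(&p, 1 << 20, stream);  /* served from the device's current pool */
    cuMemFreeAsync(p, stream);

    cuuint64_t high = 0;
    cuMemPoolGetAttribute(pool, CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, &high);
}
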
@@ -3054,9 +4264,15 @@ CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev); /** * \brief Release the primary context on the GPU * - * Releases the primary context interop on the device by decreasing the usage - * count by 1. If the usage drops to 0 the primary context of device \p dev - * will be destroyed regardless of how many threads it is current to. + * Releases the primary context interop on the device. + * A retained context should always be released once the user is done using + * it. The context is automatically reset once the last reference to it is + * released. This behavior is different when the primary context was retained + * by the CUDA runtime from CUDA 4.0 and earlier. In this case, the primary + * context remains always active. + * + * Releasing a primary context that has not been previously retained will + * fail with ::CUDA_ERROR_INVALID_CONTEXT. * * Please note that unlike ::cuCtxDestroy() this method does not pop the context * from stack in any circumstances. @@ -3067,7 +4283,8 @@ CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev); * ::CUDA_SUCCESS, * ::CUDA_ERROR_DEINITIALIZED, * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_DEVICE + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_CONTEXT * \notefnerr * * \sa ::cuDevicePrimaryCtxRetain, @@ -3089,8 +4306,7 @@ CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); * \brief Set flags for the primary context * * Sets the flags for the primary context on the device overwriting perviously - * set ones. If the primary context is already created - * ::CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE is returned. + * set ones. * * The three LSBs of the \p flags parameter can be used to control how the OS * thread, which owns the CUDA context at the time of an API call, interacts @@ -3121,13 +4337,16 @@ CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); * \e C > \e P, then CUDA will yield to other OS threads when waiting for * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). - * However, on low power devices like Tegra, it always defaults to - * ::CU_CTX_SCHED_BLOCKING_SYNC. + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. * * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory * after resizing local memory for a kernel. This can prevent thrashing by * local memory allocations when launching many kernels with high local - * memory usage at the cost of potentially increased memory usage. + * memory usage at the cost of potentially increased memory usage.
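
/*
 * Usage sketch for the primary-context lifetime rules described above: every
 * successful ::cuDevicePrimaryCtxRetain must be balanced by a
 * ::cuDevicePrimaryCtxRelease. Illustrative only; error checking omitted.
 */
static void example_primary_context(CUdevice dev)
{
    CUcontext ctx;
    cuDevicePrimaryCtxRetain(&ctx, dev);   /* not pushed onto the context stack */
    cuCtxSetCurrent(ctx);                  /* make it current explicitly */

    /* ... driver API work on the primary context ... */

    cuDevicePrimaryCtxRelease(dev);        /* matching release; may reset the context */
}
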
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. * * \param dev - Device for which the primary context flags are set * \param flags - New flags for the device @@ -3138,7 +4357,6 @@ CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); * ::CUDA_ERROR_NOT_INITIALIZED, * ::CUDA_ERROR_INVALID_DEVICE, * ::CUDA_ERROR_INVALID_VALUE, - * ::CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE * \notefnerr * * \sa ::cuDevicePrimaryCtxRetain, @@ -3186,6 +4404,8 @@ CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, i * it is recommended to use ::cuDevicePrimaryCtxRelease() in most cases. * However it is safe for other modules to call ::cuDevicePrimaryCtxRelease() * even after resetting the device. + * Resetting the primary context does not release it, an application that has + * retained the primary context should explicitly release its usage. * * \param dev - Device for which primary context is destroyed * @@ -3213,10 +4433,38 @@ CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, i */ CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); -#endif /* __CUDA_API_VERSION >= 7000 */ - /** @} */ /* END CUDA_PRIMARY_CTX */ +/** + * \brief Returns information about the execution affinity support of the device. + * + * Returns in \p *pi whether execution affinity type \p type is supported by device \p dev. + * The supported types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: 1 if context with limited SMs is supported by the device, + * or 0 if not; + * + * \param pi - 1 if the execution affinity type \p type is supported by the device, or 0 if not + * \param type - Execution affinity type to query + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type, CUdevice dev); /** * \defgroup CUDA_CTX Context Management @@ -3227,13 +4475,17 @@ CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); * This section describes the context management functions of the low-level * CUDA driver application programming interface. * + * Please note that some functions are described in + * \ref CUDA_PRIMARY_CTX "Primary Context Management" section. + * * @{ */ -#if __CUDA_API_VERSION >= 3020 /** * \brief Create a CUDA context * + * \note In most cases it is recommended to use ::cuDevicePrimaryCtxRetain. + * * Creates a new CUDA context and associates it with the calling thread. The * \p flags parameter is described below. The context is created with a usage * count of 1 and the caller of ::cuCtxCreate() must call ::cuCtxDestroy() or @@ -3270,8 +4522,9 @@ CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); * \e C > \e P, then CUDA will yield to other OS threads when waiting for * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). - * However, on low power devices like Tegra, it always defaults to - * ::CU_CTX_SCHED_BLOCKING_SYNC. 
+ * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. * * - ::CU_CTX_MAP_HOST: Instruct CUDA to support mapped pinned allocations. * This flag must be set in order to allocate pinned host memory that is @@ -3280,7 +4533,10 @@ CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory * after resizing local memory for a kernel. This can prevent thrashing by * local memory allocations when launching many kernels with high local - * memory usage at the cost of potentially increased memory usage. + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). * * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute() @@ -3318,9 +4574,114 @@ CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); * ::cuCtxSynchronize */ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); -#endif /* __CUDA_API_VERSION >= 3020 */ -#if __CUDA_API_VERSION >= 4000 +/** + * \brief Create a CUDA context with execution affinity + * + * Creates a new CUDA context with execution affinity and associates it with + * the calling thread. The \p paramsArray and \p flags parameter are described below. + * The context is created with a usage count of 1 and the caller of ::cuCtxCreate() must + * call ::cuCtxDestroy() or when done using the context. If a context is already + * current to the thread, it is supplanted by the newly created context and may + * be restored by a subsequent call to ::cuCtxPopCurrent(). + * + * The type and the amount of execution resource the context can use is limited by \p paramsArray + * and \p numParams. The \p paramsArray is an array of \p CUexecAffinityParam and the \p numParams + * describes the size of the array. If two \p CUexecAffinityParam in the array have the same type, + * the latter execution affinity parameter overrides the former execution affinity parameter. + * The supported execution affinity types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT limits the portion of SMs that the context can use. The portion + * of SMs is specified as the number of SMs via \p CUexecAffinitySmCount. This limit will be internally + * rounded up to the next hardware-supported amount. Hence, it is imperative to query the actual execution + * affinity of the context via \p cuCtxGetExecAffinity after context creation. Currently, this attribute + * is only supported under Volta+ MPS. + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. + * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. + * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_MAP_HOST: Instruct CUDA to support mapped pinned allocations. + * This flag must be set in order to allocate pinned host memory that is + * accessible to the GPU. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). + * + * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of + * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute() + * can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to determine the + * compute mode of the device. The nvidia-smi tool can be used to set + * the compute mode for * devices. + * Documentation for nvidia-smi can be obtained by passing a + * -h option to it. + * + * \param pctx - Returned context handle of the new context + * \param paramsArray - Execution affinity parameters + * \param numParams - Number of execution affinity parameters + * \param flags - Context creation flags + * \param dev - Device to create context on + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::CUexecAffinityParam + */ +CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArray, int numParams, unsigned int flags, CUdevice dev); + /** * \brief Destroy a CUDA context * @@ -3329,6 +4690,13 @@ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); * It is the responsibility of the calling function to ensure that no API * call issues using \p ctx while ::cuCtxDestroy() is executing. * + * Destroys and cleans up all resources associated with the context. + * It is the caller's responsibility to ensure that the context or its resources + * are not accessed or passed in subsequent API calls and doing so will result in undefined behavior. + * These resources include CUDA types such as ::CUmodule, ::CUfunction, ::CUstream, ::CUevent, + * ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, ::CUsurfref, + * ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and ::CUexternalSemaphore. + * * If \p ctx is current to the calling thread then \p ctx will also be * popped from the current thread's context stack (as though ::cuCtxPopCurrent() * were called). 
If \p ctx is current to other threads, then \p ctx will @@ -3358,9 +4726,7 @@ CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); * ::cuCtxSynchronize */ CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); -#endif /* __CUDA_API_VERSION >= 4000 */ -#if __CUDA_API_VERSION >= 4000 /** * \brief Pushes a context on the current CPU thread * @@ -3472,7 +4838,6 @@ CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx); * ::CUDA_SUCCESS, * ::CUDA_ERROR_DEINITIALIZED, * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_VALUE * \notefnerr * * \sa @@ -3482,7 +4847,6 @@ CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx); * ::cudaGetDevice */ CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx); -#endif /* __CUDA_API_VERSION >= 4000 */ /** * \brief Returns the device ID for the current context @@ -3514,7 +4878,6 @@ CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx); */ CUresult CUDAAPI cuCtxGetDevice(CUdevice *device); -#if __CUDA_API_VERSION >= 7000 /** * \brief Returns the flags for the current context * @@ -3535,14 +4898,13 @@ CUresult CUDAAPI cuCtxGetDevice(CUdevice *device); * ::cuCtxGetApiVersion, * ::cuCtxGetCacheConfig, * ::cuCtxGetCurrent, - * ::cuCtxGetDevice + * ::cuCtxGetDevice, * ::cuCtxGetLimit, * ::cuCtxGetSharedMemConfig, * ::cuCtxGetStreamPriorityRange, * ::cudaGetDeviceFlags */ CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags); -#endif /* __CUDA_API_VERSION >= 7000 */ /** * \brief Block for a context's tasks to complete @@ -3588,6 +4950,11 @@ CUresult CUDAAPI cuCtxSynchronize(void); * discussed here. * * - ::CU_LIMIT_STACK_SIZE controls the stack size in bytes of each GPU thread. + * The driver automatically increases the per-thread stack size + * for each kernel launch as needed. This size isn't reset back to the + * original value after each launch. Setting this value will take effect + * immediately, and if necessary, the device will block until all preceding + * requested tasks are complete. * * - ::CU_LIMIT_PRINTF_FIFO_SIZE controls the size in bytes of the FIFO used * by the ::printf() device system call. Setting ::CU_LIMIT_PRINTF_FIFO_SIZE @@ -3610,7 +4977,7 @@ CUresult CUDAAPI cuCtxSynchronize(void); * launch depth of 24. When setting this limit, keep in mind that additional * levels of sync depth require the driver to reserve large amounts of device * memory which can no longer be used for user allocations. If these - * reservations of device memory fail, ::cuCtxSetLimit will return + * reservations of device memory fail, ::cuCtxSetLimit() will return * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value. * This limit is only applicable to devices of compute capability 3.5 and * higher. Attempting to set this limit on devices of compute capability less @@ -3627,13 +4994,21 @@ CUresult CUDAAPI cuCtxSynchronize(void); * runtime, this limit can be increased. Keep in mind that being able to * sustain additional pending launches will require the driver to reserve * larger amounts of device memory upfront which can no longer be used for - * allocations. If these reservations fail, ::cuCtxSetLimit will return + * allocations. If these reservations fail, ::cuCtxSetLimit() will return * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value. * This limit is only applicable to devices of compute capability 3.5 and * higher. Attempting to set this limit on devices of compute capability less * than 3.5 will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being * returned. 
* + * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch granularity. + * Values can range from 0B to 128B. This is purely a performence hint and + * it can be ignored or clamped depending on the platform. + * + * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes availabe for + * persisting L2 cache. This is purely a performance hint and it can be + * ignored or clamped depending on the platform. + * * \param limit - Limit to set * \param value - Size of limit * @@ -3675,6 +5050,8 @@ CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value); * child grid launches to complete. * - ::CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT: maximum number of outstanding * device runtime launches that can be made from this context. + * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY: L2 cache fetch granularity. + * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE: Persisting L2 cache size in bytes * * \param limit - Limit to query * \param pvalue - Returned size of limit @@ -3795,7 +5172,6 @@ CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig); */ CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config); -#if __CUDA_API_VERSION >= 4020 /** * \brief Returns the current shared memory configuration for the current context. * @@ -3890,7 +5266,6 @@ CUresult CUDAAPI cuCtxGetSharedMemConfig(CUsharedconfig *pConfig); * ::cudaDeviceSetSharedMemConfig */ CUresult CUDAAPI cuCtxSetSharedMemConfig(CUsharedconfig config); -#endif /** * \brief Gets the context's API version. @@ -3970,6 +5345,47 @@ CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version); */ CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, int *greatestPriority); +/** + * \brief Resets all persisting lines in cache to normal status. + * + * ::cuCtxResetPersistingL2Cache Resets all persisting lines in cache to normal + * status. Takes effect on function return. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuCtxResetPersistingL2Cache(void); + +/** + * \brief Returns the execution affinity setting for the current context. + * + * Returns in \p *pExecAffinity the current value of \p type. The supported + * ::CUexecAffinityType values are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: number of SMs the context is limited to use. 
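For readers of this patch, a minimal host-side sketch of the execution-affinity flow introduced here (::cuCtxCreate_v3 followed by ::cuCtxGetExecAffinity) might look as follows; `dev` is assumed to be a valid ::CUdevice from ::cuDeviceGet, and the requested SM count is only honored under Volta+ MPS:

\code
    CUexecAffinityParam req;
    req.type = CU_EXEC_AFFINITY_TYPE_SM_COUNT;
    req.param.smCount.val = 16;              /* ask for 16 SMs */

    CUcontext ctx;
    cuCtxCreate_v3(&ctx, &req, 1, 0, dev);   /* dev: assumed valid CUdevice */

    /* The granted amount may be rounded up, so query it back. */
    CUexecAffinityParam granted;
    cuCtxGetExecAffinity(&granted, CU_EXEC_AFFINITY_TYPE_SM_COUNT);
    /* granted.param.smCount.val holds the SM count actually in effect */
\endcode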
+ * + * \param type - Execution affinity type to query + * \param pExecAffinity - Returned execution affinity + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY + * \notefnerr + * + * \sa + * ::CUexecAffinityParam + */ +CUresult CUDAAPI cuCtxGetExecAffinity(CUexecAffinityParam *pExecAffinity, CUexecAffinityType type); + + /** @} */ /* END CUDA_CTX */ /** @@ -4097,6 +5513,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuCtxDetach(CUcontext ctx); * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, * ::CUDA_ERROR_NOT_FOUND, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_FILE_NOT_FOUND, @@ -4136,6 +5553,7 @@ CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname); * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_NO_BINARY_FOR_GPU, * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, @@ -4179,6 +5597,7 @@ CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image); * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_NO_BINARY_FOR_GPU, * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, @@ -4220,6 +5639,7 @@ CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, const void *image, unsigne * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, * ::CUDA_ERROR_NOT_FOUND, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_NO_BINARY_FOR_GPU, @@ -4252,6 +5672,7 @@ CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin); * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * \notefnerr + * \note_destroy_ub * * \sa ::cuModuleGetFunction, * ::cuModuleGetGlobal, @@ -4293,7 +5714,6 @@ CUresult CUDAAPI cuModuleUnload(CUmodule hmod); */ CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); -#if __CUDA_API_VERSION >= 3020 /** * \brief Returns a global pointer from a module * @@ -4328,7 +5748,6 @@ CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const cha * ::cudaGetSymbolSize */ CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, CUmodule hmod, const char *name); -#endif /* __CUDA_API_VERSION >= 3020 */ /** * \brief Returns a handle to a texture reference @@ -4396,8 +5815,6 @@ CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, const char */ CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, const char *name); -#if __CUDA_API_VERSION >= 5050 - /** * \brief Creates a pending JIT linker invocation. 
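As an illustration of the module-loading and JIT-linker entry points in this hunk, the sketch below links a PTX string into a cubin and loads a kernel from it; `ptx` and the kernel name "my_kernel" are placeholders, and error handling is omitted:

\code
    CUlinkState ls;
    void *cubin; size_t cubin_size;
    cuLinkCreate(0, NULL, NULL, &ls);
    cuLinkAddData(ls, CU_JIT_INPUT_PTX, (void *)ptx, strlen(ptx) + 1,
                  "ptx_src", 0, NULL, NULL);
    cuLinkComplete(ls, &cubin, &cubin_size);  /* cubin is owned by the link state */

    CUmodule mod; CUfunction fn;
    cuModuleLoadData(&mod, cubin);            /* copies the image */
    cuModuleGetFunction(&fn, mod, "my_kernel");
    cuLinkDestroy(ls);                        /* safe once the module has been loaded */
\endcode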
* @@ -4465,6 +5882,7 @@ cuLinkCreate(unsigned int numOptions, CUjit_option *options, void **optionValues * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_IMAGE, * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_NO_BINARY_FOR_GPU * @@ -4504,6 +5922,7 @@ cuLinkAddData(CUlinkState state, CUjitInputType type, void *data, size_t size, c * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_IMAGE, * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_NO_BINARY_FOR_GPU * @@ -4556,8 +5975,6 @@ cuLinkComplete(CUlinkState state, void **cubinOut, size_t *sizeOut); CUresult CUDAAPI cuLinkDestroy(CUlinkState state); -#endif /* __CUDA_API_VERSION >= 5050 */ - /** @} */ /* END CUDA_MODULE */ @@ -4573,12 +5990,12 @@ cuLinkDestroy(CUlinkState state); * @{ */ -#if __CUDA_API_VERSION >= 3020 /** * \brief Gets free and total memory * - * Returns in \p *free and \p *total respectively, the free and total amount of - * memory available for allocation by the CUDA context, in bytes. + * Returns in \p *total the total amount of memory available to the the current context. + * Returns in \p *free the amount of memory on the device that is free according to the OS. + * CUDA is not guaranteed to be able to allocate all of the memory that the OS reports as free. * * \param free - Returned free memory in bytes * \param total - Returned total memory in bytes @@ -4811,7 +6228,6 @@ CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, CUdevic * ::cudaMallocHost */ CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize); -#endif /* __CUDA_API_VERSION >= 3020 */ /** * \brief Frees page-locked host memory @@ -4878,9 +6294,6 @@ CUresult CUDAAPI cuMemFreeHost(void *p); * All of these flags are orthogonal to one another: a developer may allocate * memory that is portable, mapped and/or write-combined with no restrictions. * - * The CUDA context must have been created with the ::CU_CTX_MAP_HOST flag in - * order for the ::CU_MEMHOSTALLOC_DEVICEMAP flag to have any effect. - * * The ::CU_MEMHOSTALLOC_DEVICEMAP flag may be specified on CUDA contexts for * devices that do not support mapped pinned memory. 
The failure is deferred * to ::cuMemHostGetDevicePointer() because the memory may be mapped into @@ -4925,7 +6338,6 @@ CUresult CUDAAPI cuMemFreeHost(void *p); */ CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, unsigned int Flags); -#if __CUDA_API_VERSION >= 3020 /** * \brief Passes back device pointer of mapped pinned memory * @@ -4978,7 +6390,6 @@ CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, unsigned int Flags); * ::cudaHostGetDevicePointer */ CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, unsigned int Flags); -#endif /* __CUDA_API_VERSION >= 3020 */ /** * \brief Passes back flags that were used for a pinned allocation @@ -5007,8 +6418,6 @@ CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, unsigned */ CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p); -#if __CUDA_API_VERSION >= 6000 - /** * \brief Allocates memory that will be automatically managed by the Unified Memory system * @@ -5119,10 +6528,6 @@ CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p); */ CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, unsigned int flags); -#endif /* __CUDA_API_VERSION >= 6000 */ - -#if __CUDA_API_VERSION >= 4010 - /** * \brief Returns a handle to a compute device * @@ -5201,8 +6606,7 @@ CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev); * * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. - * IPC functionality on Windows is restricted to GPUs in TCC mode. - * IPC functionality is not supported on Tegra platforms. + * IPC functionality on Windows is restricted to GPUs in TCC mode * * \param pHandle - Pointer to a user allocated CUipcEventHandle * in which to return the opaque event handle @@ -5243,8 +6647,7 @@ CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event); * * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. - * IPC functionality on Windows is restricted to GPUs in TCC mode. - * IPC functionality is not supported on Tegra platforms. + * IPC functionality on Windows is restricted to GPUs in TCC mode * * \param phEvent - Returns the imported event * \param handle - Interprocess handle to open @@ -5287,8 +6690,7 @@ CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, CUipcEventHandle handle) * * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. - * IPC functionality on Windows is restricted to GPUs in TCC mode. - * IPC functionality is not supported on Tegra platforms. + * IPC functionality on Windows is restricted to GPUs in TCC mode * * \param pHandle - Pointer to user allocated ::CUipcMemHandle to return * the handle in. @@ -5297,6 +6699,7 @@ CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, CUipcEventHandle handle) * \returns * ::CUDA_SUCCESS, * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_OUT_OF_MEMORY, * ::CUDA_ERROR_MAP_FAILED, * ::CUDA_ERROR_INVALID_VALUE @@ -5327,6 +6730,10 @@ CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr); * ::CUipcMemHandles from each ::CUdevice in a given process may only be opened * by one ::CUcontext per ::CUdevice per other process. 
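A rough sketch of the IPC flow described here, split across two cooperating processes; the transport used to pass the opaque handle (a pipe, socket, shared file, ...) is application-defined and not shown:

\code
    /* Process A: export an allocation */
    CUdeviceptr d_buf;
    CUipcMemHandle h;
    cuMemAlloc(&d_buf, 1 << 20);
    cuIpcGetMemHandle(&h, d_buf);
    /* ...send the handle bytes to process B... */

    /* Process B: import, use, then close */
    CUdeviceptr d_remote;
    cuIpcOpenMemHandle(&d_remote, h, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
    /* ...launch work that reads or writes d_remote... */
    cuIpcCloseMemHandle(d_remote);
\endcode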
* + * If the memory handle has already been opened by the current context, the + * reference count on the handle is incremented by 1 and the existing device pointer + * is returned. + * * Memory returned from ::cuIpcOpenMemHandle must be freed with * ::cuIpcCloseMemHandle. * @@ -5336,8 +6743,7 @@ CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr); * * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. - * IPC functionality on Windows is restricted to GPUs in TCC mode. - * IPC functionality is not supported on Tegra platforms. + * IPC functionality on Windows is restricted to GPUs in TCC mode * * \param pdptr - Returned device pointer * \param handle - ::CUipcMemHandle to open @@ -5368,9 +6774,10 @@ CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr); CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int Flags); /** - * \brief Close memory mapped with ::cuIpcOpenMemHandle + * \brief Attempts to close memory mapped with ::cuIpcOpenMemHandle * - * Unmaps memory returnd by ::cuIpcOpenMemHandle. The original allocation + * Decrements the reference count of the memory returned by ::cuIpcOpenMemHandle by 1. + * When the reference count reaches 0, this API unmaps the memory. The original allocation * in the exporting process as well as imported mappings in other processes * will be unaffected. * @@ -5379,8 +6786,7 @@ CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, u * * IPC functionality is restricted to devices with support for unified * addressing on Linux and Windows operating systems. - * IPC functionality on Windows is restricted to GPUs in TCC mode. - * IPC functionality is not supported on Tegra platforms. + * IPC functionality on Windows is restricted to GPUs in TCC mode * * \param dptr - Device pointer returned by ::cuIpcOpenMemHandle * @@ -5401,9 +6807,6 @@ CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, u */ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); -#endif /* __CUDA_API_VERSION >= 4010 */ - -#if __CUDA_API_VERSION >= 4000 /** * \brief Registers an existing host memory range for use by CUDA * @@ -5420,9 +6823,6 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); * * This function has limited support on Mac OS X. OS 10.7 or higher is required. * - * This function is supported only on I/O coherent devices that have a non-zero value - * for the device attribute ::CU_DEVICE_ATTRIBUTE_HOST_REGISTER_SUPPORTED. - * * The \p Flags parameter enables different options to be specified that * affect the allocation, as follows. * @@ -5437,12 +6837,18 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); * - ::CU_MEMHOSTREGISTER_IOMEMORY: The pointer is treated as pointing to some * I/O memory space, e.g. the PCI Express resource of a 3rd party device. * + * - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory + * that is considered read-only by the device. On platforms without + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is + * required in order to register memory mapped to the CPU as read-only. Support + * for the use of this flag can be queried from the device attribute + * ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. 
Using this flag with + * a current context associated with a device that does not have this attribute + * set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED. + * * All of these flags are orthogonal to one another: a developer may page-lock * memory that is portable or mapped with no restrictions. * - * The CUDA context must have been created with the ::CU_CTX_MAP_HOST flag in - * order for the ::CU_MEMHOSTREGISTER_DEVICEMAP flag to have any effect. - * * The ::CU_MEMHOSTREGISTER_DEVICEMAP flag may be specified on CUDA contexts for * devices that do not support mapped pinned memory. The failure is deferred * to ::cuMemHostGetDevicePointer() because the memory may be mapped into @@ -5538,6 +6944,7 @@ CUresult CUDAAPI cuMemHostUnregister(void *p); * ::CUDA_ERROR_INVALID_VALUE * \notefnerr * \note_sync + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -5585,9 +6992,6 @@ CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount); */ CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, CUdeviceptr srcDevice, CUcontext srcContext, size_t ByteCount); -#endif /* __CUDA_API_VERSION >= 4000 */ - -#if __CUDA_API_VERSION >= 3020 /** * \brief Copies memory from Host to Device * @@ -5607,6 +7011,7 @@ CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, CUdev * ::CUDA_ERROR_INVALID_VALUE * \notefnerr * \note_sync + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -5642,6 +7047,7 @@ CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, size_t * ::CUDA_ERROR_INVALID_VALUE * \notefnerr * \note_sync + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -5789,6 +7195,7 @@ CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, size_t sr * ::CUDA_ERROR_INVALID_VALUE * \notefnerr * \note_sync + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -5825,6 +7232,7 @@ CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset, const void *sr * ::CUDA_ERROR_INVALID_VALUE * \notefnerr * \note_sync + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -6374,9 +7782,7 @@ CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy); * ::cudaMemcpy3D */ CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy); -#endif /* __CUDA_API_VERSION >= 3020 */ -#if __CUDA_API_VERSION >= 4000 /** * \brief Copies memory between contexts * @@ -6426,6 +7832,7 @@ CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy); * \notefnerr * \note_async * \note_null_stream + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -6477,9 +7884,7 @@ CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, size_t ByteCoun * ::cudaMemcpyPeerAsync */ CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, CUdeviceptr srcDevice, CUcontext srcContext, size_t ByteCount, CUstream hStream); -#endif /* __CUDA_API_VERSION >= 
4000 */ -#if __CUDA_API_VERSION >= 3020 /** * \brief Copies memory from Host to Device * @@ -6502,6 +7907,7 @@ CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, * \notefnerr * \note_async * \note_null_stream + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -6542,6 +7948,7 @@ CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, s * \notefnerr * \note_async * \note_null_stream + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -6625,6 +8032,7 @@ CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, * \notefnerr * \note_async * \note_null_stream + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -6666,6 +8074,7 @@ CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset, const voi * \notefnerr * \note_async * \note_null_stream + * \note_memcpy * * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, @@ -7025,9 +8434,7 @@ CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream); * ::cudaMemcpy3DAsync */ CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream); -#endif /* __CUDA_API_VERSION >= 3020 */ -#if __CUDA_API_VERSION >= 4000 /** * \brief Copies memory between contexts asynchronously. * @@ -7053,9 +8460,7 @@ CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream); * ::cudaMemcpy3DPeerAsync */ CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, CUstream hStream); -#endif /* __CUDA_API_VERSION >= 4000 */ -#if __CUDA_API_VERSION >= 3020 /** * \brief Initializes device memory * @@ -7171,7 +8576,7 @@ CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N); * ::cuMemAllocPitch(). * * \param dstDevice - Destination device pointer - * \param dstPitch - Pitch of destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) * \param uc - Value to set * \param Width - Width of row * \param Height - Number of rows @@ -7212,7 +8617,7 @@ CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, unsigned c * ::cuMemAllocPitch(). * * \param dstDevice - Destination device pointer - * \param dstPitch - Pitch of destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) * \param us - Value to set * \param Width - Width of row * \param Height - Number of rows @@ -7253,7 +8658,7 @@ CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, unsigned * ::cuMemAllocPitch(). * * \param dstDevice - Destination device pointer - * \param dstPitch - Pitch of destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) * \param ui - Value to set * \param Width - Width of row * \param Height - Number of rows @@ -7403,7 +8808,7 @@ CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, size_t * ::cuMemAllocPitch(). 
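To make the pitch parameter used by these 2D memset entry points concrete, a small sketch with a pitched allocation; `stream` is assumed to be an existing ::CUstream and error handling is omitted:

\code
    CUdeviceptr d_img;
    size_t pitch;
    cuMemAllocPitch(&d_img, &pitch, 1024, 768, 4);          /* 1024-byte rows, 768 rows */
    cuMemsetD2D8Async(d_img, pitch, 0, 1024, 768, stream);  /* clear every row          */
    cuStreamSynchronize(stream);
    cuMemFree(d_img);
\endcode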
* * \param dstDevice - Destination device pointer - * \param dstPitch - Pitch of destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) * \param uc - Value to set * \param Width - Width of row * \param Height - Number of rows @@ -7446,7 +8851,7 @@ CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, unsig * ::cuMemAllocPitch(). * * \param dstDevice - Destination device pointer - * \param dstPitch - Pitch of destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) * \param us - Value to set * \param Width - Width of row * \param Height - Number of rows @@ -7489,7 +8894,7 @@ CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * ::cuMemAllocPitch(). * * \param dstDevice - Destination device pointer - * \param dstPitch - Pitch of destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) * \param ui - Value to set * \param Width - Width of row * \param Height - Number of rows @@ -7582,7 +8987,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * float16's: * \code CUDA_ARRAY_DESCRIPTOR desc; - desc.FormatFlags = CU_AD_FORMAT_HALF; + desc.Format = CU_AD_FORMAT_HALF; desc.NumChannels = 4; desc.Width = width; desc.Height = height; @@ -7592,7 +8997,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * of which is two 8-bit unsigned chars: * \code CUDA_ARRAY_DESCRIPTOR arrayDesc; - desc.FormatFlags = CU_AD_FORMAT_UNSIGNED_INT8; + desc.Format = CU_AD_FORMAT_UNSIGNED_INT8; desc.NumChannels = 2; desc.Width = width; desc.Height = height; @@ -7658,8 +9063,88 @@ CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, const CUDA_ARRAY_DESCRIPTOR *pA * ::cudaArrayGetInfo */ CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, CUarray hArray); -#endif /* __CUDA_API_VERSION >= 3020 */ +/** + * \brief Returns the layout properties of a sparse CUDA array + * + * Returns the layout properties of a sparse CUDA array in \p sparseProperties + * If the CUDA array is not allocated with flag ::CUDA_ARRAY3D_SPARSE + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * If the returned value in ::CUDA_ARRAY_SPARSE_PROPERTIES::flags contains ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL, + * then ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize represents the total size of the array. Otherwise, it will be zero. + * Also, the returned value in ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailFirstLevel is always zero. + * Note that the \p array must have been allocated using ::cuArrayCreate or ::cuArray3DCreate. For CUDA arrays obtained + * using ::cuMipmappedArrayGetLevel, ::CUDA_ERROR_INVALID_VALUE will be returned. Instead, ::cuMipmappedArrayGetSparseProperties + * must be used to obtain the sparse properties of the entire CUDA mipmapped array to which \p array belongs to. 
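For illustration, querying the tile layout of a sparse array might look like the sketch below; `sparseArray` is assumed to be a CUDA array created with the ::CUDA_ARRAY3D_SPARSE flag:

\code
    CUDA_ARRAY_SPARSE_PROPERTIES sp;
    cuArrayGetSparseProperties(&sp, sparseArray);
    /* sp.tileExtent.width / .height / .depth give the tile dimensions;
       sp.flags and sp.miptailSize follow the rules described above. */
\endcode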
+ * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] sparseProperties - Pointer to ::CUDA_ARRAY_SPARSE_PROPERTIES + * \param[in] array - CUDA array to get the sparse properties of + * \sa ::cuMipmappedArrayGetSparseProperties, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuArrayGetSparseProperties(CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUarray array); + +/** + * \brief Returns the layout properties of a sparse CUDA mipmapped array + * + * Returns the sparse array layout properties in \p sparseProperties + * If the CUDA mipmapped array is not allocated with flag ::CUDA_ARRAY3D_SPARSE + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * For non-layered CUDA mipmapped arrays, ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize returns the + * size of the mip tail region. The mip tail region includes all mip levels whose width, height or depth + * is less than that of the tile. + * For layered CUDA mipmapped arrays, if ::CUDA_ARRAY_SPARSE_PROPERTIES::flags contains ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL, + * then ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize specifies the size of the mip tail of all layers combined. + * Otherwise, ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize specifies mip tail size per layer. + * The returned value of ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailFirstLevel is valid only if ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize is non-zero. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] sparseProperties - Pointer to ::CUDA_ARRAY_SPARSE_PROPERTIES + * \param[in] mipmap - CUDA mipmapped array to get the sparse properties of + * \sa ::cuArrayGetSparseProperties, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuMipmappedArrayGetSparseProperties(CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUmipmappedArray mipmap); + +/** + * \brief Gets a CUDA array plane from a CUDA array + * + * Returns in \p pPlaneArray a CUDA array that represents a single format plane + * of the CUDA array \p hArray. + * + * If \p planeIdx is greater than the maximum number of planes in this array or if the array does + * not have a multi-planar format e.g: ::CU_AD_FORMAT_NV12, then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Note that if the \p hArray has format ::CU_AD_FORMAT_NV12, then passing in 0 for \p planeIdx returns + * a CUDA array of the same size as \p hArray but with one channel and ::CU_AD_FORMAT_UNSIGNED_INT8 as its format. + * If 1 is passed for \p planeIdx, then the returned CUDA array has half the height and width + * of \p hArray with two channels and ::CU_AD_FORMAT_UNSIGNED_INT8 as its format. 
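As a short illustration of ::cuArrayGetPlane, extracting the luma plane of a multi-planar array; `nv12Array` is a placeholder for an existing array with format ::CU_AD_FORMAT_NV12:

\code
    CUarray lumaPlane;
    cuArrayGetPlane(&lumaPlane, nv12Array, 0);  /* plane 0: full-size, one UINT8 channel */
\endcode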
+ * + * \param pPlaneArray - Returned CUDA array referenced by the \p planeIdx + * \param hArray - Multiplanar CUDA array + * \param planeIdx - Plane index + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::cuArrayCreate, + * ::cudaGetArrayPlane + */ +CUresult CUDAAPI cuArrayGetPlane(CUarray *pPlaneArray, CUarray hArray, unsigned int planeIdx); /** * \brief Destroys a CUDA array @@ -7692,7 +9177,6 @@ CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, C */ CUresult CUDAAPI cuArrayDestroy(CUarray hArray); -#if __CUDA_API_VERSION >= 3020 /** * \brief Creates a 3D CUDA array * @@ -7839,7 +9323,7 @@ CUresult CUDAAPI cuArrayDestroy(CUarray hArray); * 4x16-bit float16's: * \code CUDA_ARRAY3D_DESCRIPTOR desc; - desc.FormatFlags = CU_AD_FORMAT_HALF; + desc.Format = CU_AD_FORMAT_HALF; desc.NumChannels = 4; desc.Width = width; desc.Height = height; @@ -7910,9 +9394,6 @@ CUresult CUDAAPI cuArray3DCreate(CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR * ::cudaArrayGetInfo */ CUresult CUDAAPI cuArray3DGetDescriptor(CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray); -#endif /* __CUDA_API_VERSION >= 3020 */ - -#if __CUDA_API_VERSION >= 5000 /** * \brief Creates a CUDA mipmapped array @@ -8111,10 +9592,932 @@ CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, CUmipmappedArray */ CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); -#endif /* __CUDA_API_VERSION >= 5000 */ - /** @} */ /* END CUDA_MEM */ +/** + * \defgroup CUDA_VA Virtual Memory Management + * + * ___MANBRIEF___ virtual memory management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the virtual memory management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** +* \brief Allocate an address range reservation. +* +* Reserves a virtual address range based on the given parameters, giving +* the starting address of the range in \p ptr. This API requires a system that +* supports UVA. The size and address parameters must be a multiple of the +* host page size and the alignment must be a power of two or zero for default +* alignment. +* +* \param[out] ptr - Resulting pointer to start of virtual address range allocated +* \param[in] size - Size of the reserved virtual address range requested +* \param[in] alignment - Alignment of the reserved virtual address range requested +* \param[in] addr - Fixed starting address range requested +* \param[in] flags - Currently unused, must be zero +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_OUT_OF_MEMORY, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemAddressFree +*/ +CUresult CUDAAPI cuMemAddressReserve(CUdeviceptr *ptr, size_t size, size_t alignment, CUdeviceptr addr, unsigned long long flags); + +/** +* \brief Free an address range reservation. +* +* Frees a virtual address range reserved by cuMemAddressReserve. The size +* must match what was given to memAddressReserve and the ptr given must +* match what was returned from memAddressReserve. 
+* +* \param[in] ptr - Starting address of the virtual address range to free +* \param[in] size - Size of the virtual address region to free +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemAddressReserve +*/ +CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size); + +/** +* \brief Create a CUDA memory handle representing a memory allocation of a given size described by the given properties +* +* This creates a memory allocation on the target device specified through the +* \p prop strcuture. The created allocation will not have any device or host +* mappings. The generic memory \p handle for the allocation can be +* mapped to the address space of calling process via ::cuMemMap. This handle +* cannot be transmitted directly to other processes (see +* ::cuMemExportToShareableHandle). On Windows, the caller must also pass +* an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which +* limits or allows access to this handle for a recepient process (see +* ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this +* allocation must be a multiple of the the value given via +* ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM +* flag. +* If ::CUmemAllocationProp::allocFlags::usage contains ::CU_MEM_CREATE_USAGE_TILE_POOL flag then +* the memory allocation is intended only to be used as backing tile pool for sparse CUDA arrays +* and sparse CUDA mipmapped arrays. +* (see ::cuMemMapArrayAsync). +* +* \param[out] handle - Value of handle returned. All operations on this allocation are to be performed using this handle. +* \param[in] size - Size of the allocation requested +* \param[in] prop - Properties of the allocation to create. +* \param[in] flags - flags for future use, must be zero now. +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_OUT_OF_MEMORY, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* +* \sa ::cuMemRelease, ::cuMemExportToShareableHandle, ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, const CUmemAllocationProp *prop, unsigned long long flags); + +/** +* \brief Release a memory handle representing a memory allocation which was previously allocated through cuMemCreate. +* +* Frees the memory that was allocated on a device through cuMemCreate. +* +* The memory allocation will be freed when all outstanding mappings to the memory +* are unmapped and when all outstanding references to the handle (including it's +* shareable counterparts) are also released. The generic memory handle can be +* freed when there are still outstanding mappings made with this handle. Each +* time a recepient process imports a shareable handle, it needs to pair it with +* ::cuMemRelease for the handle to be freed. If \p handle is not a valid handle +* the behavior is undefined. +* +* \param[in] handle Value of handle which was returned previously by cuMemCreate. 
+* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* +* \sa ::cuMemCreate +*/ +CUresult CUDAAPI cuMemRelease(CUmemGenericAllocationHandle handle); + +/** +* \brief Maps an allocation handle to a reserved virtual address range. +* +* Maps bytes of memory represented by \p handle starting from byte \p offset to +* \p size to address range [\p addr, \p addr + \p size]. This range must be an +* address reservation previously reserved with ::cuMemAddressReserve, and +* \p offset + \p size must be less than the size of the memory allocation. +* Both \p ptr, \p size, and \p offset must be a multiple of the value given via +* ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM flag. +* +* Please note calling ::cuMemMap does not make the address accessible, +* the caller needs to update accessibility of a contiguous mapped VA +* range by calling ::cuMemSetAccess. +* +* Once a recipient process obtains a shareable memory handle +* from ::cuMemImportFromShareableHandle, the process must +* use ::cuMemMap to map the memory into its address ranges before +* setting accessibility with ::cuMemSetAccess. +* +* ::cuMemMap can only create mappings on VA range reservations +* that are not currently mapped. +* +* \param[in] ptr - Address where memory will be mapped. +* \param[in] size - Size of the memory mapping. +* \param[in] offset - Offset into the memory represented by +* - \p handle from which to start mapping +* - Note: currently must be zero. +* \param[in] handle - Handle to a shareable memory +* \param[in] flags - flags for future use, must be zero now. +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_OUT_OF_MEMORY, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* +* \sa ::cuMemUnmap, ::cuMemSetAccess, ::cuMemCreate, ::cuMemAddressReserve, ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, CUmemGenericAllocationHandle handle, unsigned long long flags); + +/** + * \brief Maps or unmaps subregions of sparse CUDA arrays and sparse CUDA mipmapped arrays + * + * Performs map or unmap operations on subregions of sparse CUDA arrays and sparse CUDA mipmapped arrays. + * Each operation is specified by a ::CUarrayMapInfo entry in the \p mapInfoList array of size \p count. 
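Tying together the virtual-memory entry points of this section (::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and ::cuMemSetAccess, the latter documented further below), a hedged end-to-end sketch with error handling omitted; `dev` is assumed to be a valid ::CUdevice:

\code
    CUmemAllocationProp prop = {0};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = dev;

    size_t gran;
    cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    size_t size = gran;                      /* must be a multiple of the granularity */

    CUmemGenericAllocationHandle h;
    CUdeviceptr va;
    cuMemCreate(&h, size, &prop, 0);         /* physical backing                      */
    cuMemAddressReserve(&va, size, 0, 0, 0); /* virtual address range                 */
    cuMemMap(va, size, 0, h, 0);             /* bind the two                          */

    CUmemAccessDesc acc = {0};
    acc.location = prop.location;
    acc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(va, size, &acc, 1);       /* mappings start with no access         */

    /* ... use va like any other device pointer ... */

    cuMemUnmap(va, size);
    cuMemRelease(h);
    cuMemAddressFree(va, size);
\endcode

Querying the granularity first matters because both the allocation size and the reserved range must be multiples of it, as the descriptions above require.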
+ * The structure ::CUarrayMapInfo is defined as follow: + \code + typedef struct CUarrayMapInfo_st { + CUresourcetype resourceType; + union { + CUmipmappedArray mipmap; + CUarray array; + } resource; + + CUarraySparseSubresourceType subresourceType; + union { + struct { + unsigned int level; + unsigned int layer; + unsigned int offsetX; + unsigned int offsetY; + unsigned int offsetZ; + unsigned int extentWidth; + unsigned int extentHeight; + unsigned int extentDepth; + } sparseLevel; + struct { + unsigned int layer; + unsigned long long offset; + unsigned long long size; + } miptail; + } subresource; + + CUmemOperationType memOperationType; + + CUmemHandleType memHandleType; + union { + CUmemGenericAllocationHandle memHandle; + } memHandle; + + unsigned long long offset; + unsigned int deviceBitMask; + unsigned int flags; + unsigned int reserved[2]; + } CUarrayMapInfo; + \endcode + * + * where ::CUarrayMapInfo::resourceType specifies the type of resource to be operated on. + * If ::CUarrayMapInfo::resourceType is set to ::CUresourcetype::CU_RESOURCE_TYPE_ARRAY then + * ::CUarrayMapInfo::resource::array must be set to a valid sparse CUDA array handle. + * The CUDA array must be either a 2D, 2D layered or 3D CUDA array and must have been allocated using + * ::cuArrayCreate or ::cuArray3DCreate with the flag ::CUDA_ARRAY3D_SPARSE. + * For CUDA arrays obtained using ::cuMipmappedArrayGetLevel, ::CUDA_ERROR_INVALID_VALUE will be returned. + * If ::CUarrayMapInfo::resourceType is set to ::CUresourcetype::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY + * then ::CUarrayMapInfo::resource::mipmap must be set to a valid sparse CUDA mipmapped array handle. + * The CUDA mipmapped array must be either a 2D, 2D layered or 3D CUDA mipmapped array and must have been + * allocated using ::cuMipmappedArrayCreate with the flag ::CUDA_ARRAY3D_SPARSE. + * + * ::CUarrayMapInfo::subresourceType specifies the type of subresource within the resource. + * ::CUarraySparseSubresourceType_enum is defined as: + \code + typedef enum CUarraySparseSubresourceType_enum { + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL = 0, + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL = 1 + } CUarraySparseSubresourceType; + \endcode + * + * where ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL indicates a + * sparse-miplevel which spans at least one tile in every dimension. The remaining miplevels which + * are too small to span at least one tile in any dimension constitute the mip tail region as indicated by + * ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL subresource type. + * + * If ::CUarrayMapInfo::subresourceType is set to ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL + * then ::CUarrayMapInfo::subresource::sparseLevel struct must contain valid array subregion offsets and extents. + * The ::CUarrayMapInfo::subresource::sparseLevel::offsetX, ::CUarrayMapInfo::subresource::sparseLevel::offsetY + * and ::CUarrayMapInfo::subresource::sparseLevel::offsetZ must specify valid X, Y and Z offsets respectively. + * The ::CUarrayMapInfo::subresource::sparseLevel::extentWidth, ::CUarrayMapInfo::subresource::sparseLevel::extentHeight + * and ::CUarrayMapInfo::subresource::sparseLevel::extentDepth must specify valid width, height and depth extents respectively. + * These offsets and extents must be aligned to the corresponding tile dimension. + * For CUDA mipmapped arrays ::CUarrayMapInfo::subresource::sparseLevel::level must specify a valid mip level index. 
Otherwise, + * must be zero. + * For layered CUDA arrays and layered CUDA mipmapped arrays ::CUarrayMapInfo::subresource::sparseLevel::layer must specify a valid layer index. Otherwise, + * must be zero. + * ::CUarrayMapInfo::subresource::sparseLevel::offsetZ must be zero and ::CUarrayMapInfo::subresource::sparseLevel::extentDepth + * must be set to 1 for 2D and 2D layered CUDA arrays and CUDA mipmapped arrays. + * Tile extents can be obtained by calling ::cuArrayGetSparseProperties and ::cuMipmappedArrayGetSparseProperties + * + * If ::CUarrayMapInfo::subresourceType is set to ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL + * then ::CUarrayMapInfo::subresource::miptail struct must contain valid mip tail offset in + * ::CUarrayMapInfo::subresource::miptail::offset and size in ::CUarrayMapInfo::subresource::miptail::size. + * Both, mip tail offset and mip tail size must be aligned to the tile size. + * For layered CUDA mipmapped arrays which don't have the flag ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL set in ::CUDA_ARRAY_SPARSE_PROPERTIES::flags + * as returned by ::cuMipmappedArrayGetSparseProperties, ::CUarrayMapInfo::subresource::miptail::layer must specify a valid layer index. + * Otherwise, must be zero. + * + * ::CUarrayMapInfo::memOperationType specifies the type of operation. ::CUmemOperationType is defined as: + \code + typedef enum CUmemOperationType_enum { + CU_MEM_OPERATION_TYPE_MAP = 1, + CU_MEM_OPERATION_TYPE_UNMAP = 2 + } CUmemOperationType; + \endcode + * If ::CUarrayMapInfo::memOperationType is set to ::CUmemOperationType::CU_MEM_OPERATION_TYPE_MAP then the subresource + * will be mapped onto the tile pool memory specified by ::CUarrayMapInfo::memHandle at offset ::CUarrayMapInfo::offset. + * The tile pool allocation has to be created by specifying the ::CU_MEM_CREATE_USAGE_TILE_POOL flag when calling ::cuMemCreate. Also, + * ::CUarrayMapInfo::memHandleType must be set to ::CUmemHandleType::CU_MEM_HANDLE_TYPE_GENERIC. + * + * If ::CUarrayMapInfo::memOperationType is set to ::CUmemOperationType::CU_MEM_OPERATION_TYPE_UNMAP then an unmapping operation + * is performed. ::CUarrayMapInfo::memHandle must be NULL. + * + * ::CUarrayMapInfo::deviceBitMask specifies the list of devices that must map or unmap physical memory. + * Currently, this mask must have exactly one bit set, and the corresponding device must match the device associated with the stream. + * If ::CUarrayMapInfo::memOperationType is set to ::CUmemOperationType::CU_MEM_OPERATION_TYPE_MAP, the device must also match + * the device associated with the tile pool memory allocation as specified by ::CUarrayMapInfo::memHandle. + * + * ::CUarrayMapInfo::flags and ::CUarrayMapInfo::reserved[] are unused and must be set to zero. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \param[in] mapInfoList - List of ::CUarrayMapInfo + * \param[in] count - Count of ::CUarrayMapInfo in \p mapInfoList + * \param[in] hStream - Stream identifier for the stream to use for map or unmap operations + * + * \sa ::cuMipmappedArrayCreate, ::cuArrayCreate, ::cuArray3DCreate, ::cuMemCreate, ::cuArrayGetSparseProperties, ::cuMipmappedArrayGetSparseProperties + */ +CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, unsigned int count, CUstream hStream); + +/** +* \brief Unmap the backing memory of a given address range. +* +* The range must be the entire contiguous address range that was mapped to. 
In +* other words, ::cuMemUnmap cannot unmap a sub-range of an address range mapped +* by ::cuMemCreate / ::cuMemMap. Any backing memory allocations will be freed +* if there are no existing mappings and there are no unreleased memory handles. +* +* When ::cuMemUnmap returns successfully the address range is converted to an +* address reservation and can be used for a future calls to ::cuMemMap. Any new +* mapping to this virtual address will need to have access granted through +* ::cuMemSetAccess, as all mappings start with no accessibility setup. +* +* \param[in] ptr - Starting address for the virtual address range to unmap +* \param[in] size - Size of the virtual address range to unmap +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* \note_sync +* +* \sa ::cuMemCreate, ::cuMemAddressReserve +*/ +CUresult CUDAAPI cuMemUnmap(CUdeviceptr ptr, size_t size); + +/** +* \brief Set the access flags for each location specified in \p desc for the given virtual address range +* +* Given the virtual address range via \p ptr and \p size, and the locations +* in the array given by \p desc and \p count, set the access flags for the +* target locations. The range must be a fully mapped address range +* containing all allocations created by ::cuMemMap / ::cuMemCreate. +* +* \param[in] ptr - Starting address for the virtual address range +* \param[in] size - Length of the virtual address range +* \param[in] desc - Array of ::CUmemAccessDesc that describe how to change the +* - mapping for each location specified +* \param[in] count - Number of ::CUmemAccessDesc in \p desc +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* \note_sync +* +* \sa ::cuMemSetAccess, ::cuMemCreate, :cuMemMap +*/ +CUresult CUDAAPI cuMemSetAccess(CUdeviceptr ptr, size_t size, const CUmemAccessDesc *desc, size_t count); + +/** +* \brief Get the access \p flags set for the given \p location and \p ptr +* +* \param[out] flags - Flags set for this location +* \param[in] location - Location in which to check the flags for +* \param[in] ptr - Address in which to check the access flags for +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemSetAccess +*/ +CUresult CUDAAPI cuMemGetAccess(unsigned long long *flags, const CUmemLocation *location, CUdeviceptr ptr); + +/** +* \brief Exports an allocation to a requested shareable handle type +* +* Given a CUDA memory handle, create a shareable memory +* allocation handle that can be used to share the memory with other +* processes. The recipient process can convert the shareable handle back into a +* CUDA memory handle using ::cuMemImportFromShareableHandle and map +* it with ::cuMemMap. The implementation of what this handle is and how it +* can be transferred is defined by the requested handle type in \p handleType +* +* Once all shareable handles are closed and the allocation is released, the allocated +* memory referenced will be released back to the OS and uses of the CUDA handle afterward +* will lead to undefined behavior. +* +* This API can also be used in conjunction with other APIs (e.g. 
Vulkan, OpenGL) +* that support importing memory from the shareable type +* +* \param[out] shareableHandle - Pointer to the location in which to store the requested handle type +* \param[in] handle - CUDA handle for the memory allocation +* \param[in] handleType - Type of shareable handle requested (defines type and size of the \p shareableHandle output parameter) +* \param[in] flags - Reserved, must be zero +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemExportToShareableHandle(void *shareableHandle, CUmemGenericAllocationHandle handle, CUmemAllocationHandleType handleType, unsigned long long flags); + +/** +* \brief Imports an allocation from a requested shareable handle type. +* +* If the current process cannot support the memory described by this shareable +* handle, this API will error as CUDA_ERROR_NOT_SUPPORTED. +* +* \note Importing shareable handles exported from some graphics APIs(VUlkan, OpenGL, etc) +* created on devices under an SLI group may not be supported, and thus this API will +* return CUDA_ERROR_NOT_SUPPORTED. +* There is no guarantee that the contents of \p handle will be the same CUDA memory handle +* for the same given OS shareable handle, or the same underlying allocation. +* +* \param[out] handle - CUDA Memory handle for the memory allocation. +* \param[in] osHandle - Shareable Handle representing the memory allocation that is to be imported. +* \param[in] shHandleType - handle type of the exported handle ::CUmemAllocationHandleType. +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemExportToShareableHandle, ::cuMemMap, ::cuMemRelease +*/ +CUresult CUDAAPI cuMemImportFromShareableHandle(CUmemGenericAllocationHandle *handle, void *osHandle, CUmemAllocationHandleType shHandleType); + +/** +* \brief Calculates either the minimal or recommended granularity +* +* Calculates either the minimal or recommended granularity +* for a given allocation specification and returns it in granularity. This +* granularity can be used as a multiple for alignment, size, or address mapping. +* +* \param[out] granularity Returned granularity. 
+* \param[in] prop Property for which to determine the granularity for +* \param[in] option Determines which granularity to return +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemCreate, ::cuMemMap +*/ +CUresult CUDAAPI cuMemGetAllocationGranularity(size_t *granularity, const CUmemAllocationProp *prop, CUmemAllocationGranularity_flags option); + +/** +* \brief Retrieve the contents of the property structure defining properties for this handle +* +* \param[out] prop - Pointer to a properties structure which will hold the information about this handle +* \param[in] handle - Handle which to perform the query on +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemCreate, ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemGetAllocationPropertiesFromHandle(CUmemAllocationProp *prop, CUmemGenericAllocationHandle handle); + +/** +* \brief Given an address \p addr, returns the allocation handle of the backing memory allocation. +* +* The handle is guaranteed to be the same handle value used to map the memory. If the address +* requested is not mapped, the function will fail. The returned handle must be released with +* corresponding number of calls to ::cuMemRelease. +* +* \note The address \p addr, can be any address in a range previously mapped +* by ::cuMemMap, and not necessarily the start address. +* +* \param[out] handle CUDA Memory handle for the backing memory allocation. +* \param[in] addr Memory address to query, that has been mapped previously. +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemCreate, ::cuMemRelease, ::cuMemMap +*/ +CUresult CUDAAPI cuMemRetainAllocationHandle(CUmemGenericAllocationHandle *handle, void *addr); + +/** @} */ /* END CUDA_VA */ + +/** + * \defgroup CUDA_MALLOC_ASYNC Stream Ordered Memory Allocator + * + * ___MANBRIEF___ Functions for performing allocation and free operations in stream order. + * Functions for controlling the behavior of the underlying allocator. + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream ordered memory allocator exposed by the + * low-level CUDA driver application programming interface. + * + * @{ + * + * \section CUDA_MALLOC_ASYNC_overview overview + * + * The asynchronous allocator allows the user to allocate and free in stream order. + * All asynchronous accesses of the allocation must happen between + * the stream executions of the allocation and the free. If the memory is accessed + * outside of the promised stream order, a use before allocation / use after free error + * will cause undefined behavior. + * + * The allocator is free to reallocate the memory as long as it can guarantee + * that compliant memory accesses will not overlap temporally. + * The allocator may refer to internal stream ordering as well as inter-stream dependencies + * (such as CUDA events and null stream dependencies) when establishing the temporal guarantee. + * The allocator may also insert inter-stream dependencies to establish the temporal guarantee. 
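A minimal sketch of the stream-ordered contract described above, on a single stream and with error handling omitted:

\code
    CUstream s;
    CUdeviceptr d_tmp;
    cuStreamCreate(&s, CU_STREAM_NON_BLOCKING);
    cuMemAllocAsync(&d_tmp, 1 << 20, s);
    cuMemsetD8Async(d_tmp, 0, 1 << 20, s);  /* any use must be ordered after the alloc */
    cuMemFreeAsync(d_tmp, s);               /* ...and before the free in stream order  */
    cuStreamSynchronize(s);
    cuStreamDestroy(s);
\endcode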
+ * + * \section CUDA_MALLOC_ASYNC_support Supported Platforms + * + * Whether or not a device supports the integrated stream ordered memory allocator + * may be queried by calling ::cuDeviceGetAttribute() with the device attribute + * ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED + */ + +/** + * \brief Frees memory with stream ordered semantics + * + * Inserts a free operation into \p hStream. + * The allocation must not be accessed after stream execution reaches the free. + * After this API returns, accessing the memory from any subsequent work launched on the GPU + * or querying its pointer attributes results in undefined behavior. + * + * \note During stream capture, this function results in the creation of a free node and + * must therefore be passed the address of a graph allocation. + * + * \param dptr - memory to free + * \param hStream - The stream establishing the stream ordering contract. + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current context), + * ::CUDA_ERROR_NOT_SUPPORTED + */ +CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); + +/** + * \brief Allocates memory with stream ordered semantics + * + * Inserts an allocation operation into \p hStream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the the allocation operation completes. + * The allocation comes from the memory pool current to the stream's device. + * + * \note The default memory pool of a device contains device memory from that device. + * \note Basic stream ordering allows future work submitted into the same stream to use the allocation. + * Stream query, stream synchronize, and CUDA events can be used to guarantee that the allocation + * operation completes before work submitted in a separate stream runs. + * \note During stream capture, this function results in the creation of an allocation node. In this case, + * the allocation is owned by the graph instead of the memory pool. The memory pool's properties + * are used to set the node's creation parameters. + * + * \param[out] dptr - Returned device pointer + * \param[in] bytesize - Number of bytes to allocate + * \param[in] hStream - The stream establishing the stream ordering contract and the memory pool to allocate from + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current context), + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemAllocFromPoolAsync, ::cuMemFreeAsync, ::cuDeviceSetMemPool, + * ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolSetAccess, ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemAllocAsync(CUdeviceptr *dptr, size_t bytesize, CUstream hStream); + +/** + * \brief Tries to release memory back to the OS + * + * Releases memory back to the OS until the pool contains fewer than minBytesToKeep + * reserved bytes, or there is no more memory that the allocator can safely release. + * The allocator cannot release OS allocations that back outstanding asynchronous allocations. + * The OS allocations may happen at different granularity from the user allocations. + * + * \note: Allocations that have not been freed count as outstanding. 
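A minimal stream-ordered allocation/free sketch, assuming a current context and an existing stream; the work in the middle is elided and error handling is omitted:

#include <cuda.h>

void run_step(CUstream stream, size_t bytes)
{
    CUdeviceptr buf = 0;
    cuMemAllocAsync(&buf, bytes, stream);   /* allocation ordered in `stream` */

    /* ... enqueue kernels or copies that use `buf` on `stream`, or on
     * streams made dependent on it through events ... */

    cuMemFreeAsync(buf, stream);            /* free ordered after that work */
    cuStreamSynchronize(stream);            /* optional: observe completion */
}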
+ * \note: Allocations that have been asynchronously freed but whose completion has + * not been observed on the host (eg. by a synchronize) can count as outstanding. + * + * \param[in] pool - The memory pool to trim + * \param[in] minBytesToKeep - If the pool has less than minBytesToKeep reserved, + * the TrimTo operation is a no-op. Otherwise the pool will be guaranteed to have + * at least minBytesToKeep bytes reserved after the operation. + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolTrimTo(CUmemoryPool pool, size_t minBytesToKeep); + +/** + * \brief Sets attributes of a memory pool + * + * Supported attributes are: + * - ::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + * - ::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to use memory asynchronously freed + * in another stream as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC: (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by ::cuMemFreeAsync (default enabled). + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: (value type = cuuint64_t) + * Reset the high watermark that tracks the amount of backing memory that was + * allocated for the memory pool. It is illegal to set this attribute to a non-zero value. + * - ::CU_MEMPOOL_ATTR_USED_MEM_HIGH: (value type = cuuint64_t) + * Reset the high watermark that tracks the amount of used memory that was + * allocated for the memory pool. + * + * \param[in] pool - The memory pool to modify + * \param[in] attr - The attribute to modify + * \param[in] value - Pointer to the value to assign + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolSetAttribute(CUmemoryPool pool, CUmemPool_attribute attr, void *value); + +/** + * \brief Gets attributes of a memory pool + * + * Supported attributes are: + * - ::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. 
(default 0) + * - ::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to use memory asynchronously freed + * in another stream as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC: (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by ::cuMemFreeAsync (default enabled). + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT: (value type = cuuint64_t) + * Amount of backing memory currently allocated for the mempool + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: (value type = cuuint64_t) + * High watermark of backing memory allocated for the mempool since the + * last time it was reset. + * - ::CU_MEMPOOL_ATTR_USED_MEM_CURRENT: (value type = cuuint64_t) + * Amount of memory from the pool that is currently in use by the application. + * - ::CU_MEMPOOL_ATTR_USED_MEM_HIGH: (value type = cuuint64_t) + * High watermark of the amount of memory from the pool that was in use by the application. + * + * \param[in] pool - The memory pool to get attributes of + * \param[in] attr - The attribute to get + * \param[out] value - Retrieved value + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolGetAttribute(CUmemoryPool pool, CUmemPool_attribute attr, void *value); + +/** + * \brief Controls visibility of pools between devices + * + * \param[in] pool - The pool being modified + * \param[in] map - Array of access descriptors. Each descriptor instructs the access to enable for a single gpu. + * \param[in] count - Number of descriptors in the map array. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolSetAccess(CUmemoryPool pool, const CUmemAccessDesc *map, size_t count); + +/** + * \brief Returns the accessibility of a pool from a device + * + * Returns the accessibility of the pool's memory from the specified location. + * + * \param[out] flags - the accessibility of the pool from the specified location + * \param[in] memPool - the pool being queried + * \param[in] location - the location accessing the pool + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolGetAccess(CUmemAccess_flags *flags, CUmemoryPool memPool, CUmemLocation *location); + +/** + * \brief Creates a memory pool + * + * Creates a CUDA memory pool and returns the handle in \p pool. The \p poolProps determines + * the properties of the pool such as the backing device and IPC capabilities. + * + * By default, the pool's memory will be accessible from the device it is allocated on. 
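A sketch of creating an explicit pool with a release threshold and peer access, assuming two valid device ordinals and that peer access between them is supported (error handling omitted):

#include <cuda.h>

CUmemoryPool make_pool(int device, int peerDevice)
{
    CUmemPoolProps props = {0};
    props.allocType     = CU_MEM_ALLOCATION_TYPE_PINNED;
    props.handleTypes   = CU_MEM_HANDLE_TYPE_NONE;      /* no IPC needed */
    props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    props.location.id   = device;

    CUmemoryPool pool;
    cuMemPoolCreate(&pool, &props);

    /* Hold up to 64 MiB instead of returning it to the OS at every sync. */
    cuuint64_t threshold = 64ull << 20;
    cuMemPoolSetAttribute(pool, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &threshold);

    /* Let a peer device read and write allocations from this pool. */
    CUmemAccessDesc access = {0};
    access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access.location.id   = peerDevice;
    access.flags         = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemPoolSetAccess(pool, &access, 1);

    return pool;  /* allocate with cuMemAllocFromPoolAsync; cuMemPoolDestroy when done */
}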
+ * + * \note Specifying CU_MEM_HANDLE_TYPE_NONE creates a memory pool that will not support IPC. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuDeviceSetMemPool, ::cuDeviceGetMemPool, ::cuDeviceGetDefaultMemPool, + * ::cuMemAllocFromPoolAsync, ::cuMemPoolExportToShareableHandle + */ +CUresult CUDAAPI cuMemPoolCreate(CUmemoryPool *pool, const CUmemPoolProps *poolProps); + +/** + * \brief Destroys the specified memory pool + * + * If any pointers obtained from this pool haven't been freed or + * the pool has free operations that haven't completed + * when ::cuMemPoolDestroy is invoked, the function will return immediately and the + * resources associated with the pool will be released automatically + * once there are no more outstanding allocations. + * + * Destroying the current mempool of a device sets the default mempool of + * that device as the current mempool for that device. + * + * \note A device's default memory pool cannot be destroyed. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemFreeAsync, ::cuDeviceSetMemPool, ::cuDeviceGetMemPool, + * ::cuDeviceGetDefaultMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolDestroy(CUmemoryPool pool); + +/** + * \brief Allocates memory from a specified pool with stream ordered semantics. + * + * Inserts an allocation operation into \p hStream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the the allocation operation completes. + * The allocation comes from the specified memory pool. + * + * \note + * - The specified memory pool may be from a device different than that of the specified \p hStream. + * + * - Basic stream ordering allows future work submitted into the same stream to use the allocation. + * Stream query, stream synchronize, and CUDA events can be used to guarantee that the allocation + * operation completes before work submitted in a separate stream runs. + * + * \note During stream capture, this function results in the creation of an allocation node. In this case, + * the allocation is owned by the graph instead of the memory pool. The memory pool's properties + * are used to set the node's creation parameters. + * + * \param[out] dptr - Returned device pointer + * \param[in] bytesize - Number of bytes to allocate + * \param[in] pool - The pool to allocate from + * \param[in] hStream - The stream establishing the stream ordering semantic + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current context), + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate, ::cuMemPoolSetAccess, + * ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, CUmemoryPool pool, CUstream hStream); + +/** + * \brief Exports a memory pool to the requested handle type. + * + * Given an IPC capable mempool, create an OS handle to share the pool with another process. + * A recipient process can convert the shareable handle into a mempool with ::cuMemPoolImportFromShareableHandle. + * Individual pointers can then be shared with the ::cuMemPoolExportPointer and ::cuMemPoolImportPointer APIs. 
+ * The implementation of what the shareable handle is and how it can be transferred is defined by the requested + * handle type. + * + * \note: To create an IPC capable mempool, create a mempool with a CUmemAllocationHandleType other than CU_MEM_HANDLE_TYPE_NONE. + * + * \param[out] handle_out - Returned OS handle + * \param[in] pool - pool to export + * \param[in] handleType - the type of handle to create + * \param[in] flags - must be 0 + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolImportFromShareableHandle, ::cuMemPoolExportPointer, + * ::cuMemPoolImportPointer, ::cuMemAllocAsync, ::cuMemFreeAsync, + * ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolSetAccess, ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemPoolExportToShareableHandle(void *handle_out, CUmemoryPool pool, CUmemAllocationHandleType handleType, unsigned long long flags); + +/** + * \brief imports a memory pool from a shared handle. + * + * Specific allocations can be imported from the imported pool with cuMemPoolImportPointer. + * + * \note Imported memory pools do not support creating new allocations. + * As such imported memory pools may not be used in cuDeviceSetMemPool + * or ::cuMemAllocFromPoolAsync calls. + * + * \param[out] pool_out - Returned memory pool + * \param[in] handle - OS handle of the pool to open + * \param[in] handleType - The type of handle being imported + * \param[in] flags - must be 0 + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolExportPointer, ::cuMemPoolImportPointer + */ +CUresult CUDAAPI cuMemPoolImportFromShareableHandle( + CUmemoryPool *pool_out, + void *handle, + CUmemAllocationHandleType handleType, + unsigned long long flags); + +/** + * \brief Export data to share a memory pool allocation between processes. + * + * Constructs \p shareData_out for sharing a specific allocation from an already shared memory pool. + * The recipient process can import the allocation with the ::cuMemPoolImportPointer api. + * The data is not a handle and may be shared through any IPC mechanism. + * + * \param[out] shareData_out - Returned export data + * \param[in] ptr - pointer to memory being exported + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolImportFromShareableHandle, ::cuMemPoolImportPointer + */ +CUresult CUDAAPI cuMemPoolExportPointer(CUmemPoolPtrExportData *shareData_out, CUdeviceptr ptr); + +/** + * \brief Import a memory pool allocation from another process. + * + * Returns in \p ptr_out a pointer to the imported memory. + * The imported memory must not be accessed before the allocation operation completes + * in the exporting process. The imported memory must be freed from all importing processes before + * being freed in the exporting process. The pointer may be freed with cuMemFree + * or cuMemFreeAsync. If cuMemFreeAsync is used, the free must be completed + * on the importing process before the free operation on the exporting process. 
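A sketch of the two halves of pool-based IPC described above, assuming the pool was created with CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR in CUmemPoolProps::handleTypes; moving the fd and the export data between processes is left to the application, and error handling is omitted:

#include <cuda.h>
#include <stdint.h>

/* Exporting process: share the pool once, then share individual pointers. */
void export_pool_and_ptr(CUmemoryPool pool, CUdeviceptr ptr,
                         int *fd_out, CUmemPoolPtrExportData *data_out)
{
    cuMemPoolExportToShareableHandle(fd_out, pool,
                                     CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0);
    cuMemPoolExportPointer(data_out, ptr);
}

/* Importing process: rebuild the pool handle, then map the shared pointer. */
CUdeviceptr import_ptr(int fd, CUmemPoolPtrExportData *data)
{
    CUmemoryPool pool;
    cuMemPoolImportFromShareableHandle(&pool, (void *)(uintptr_t)fd,
                                       CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0);
    CUdeviceptr ptr = 0;
    cuMemPoolImportPointer(&ptr, pool, data);
    return ptr;  /* free here with cuMemFree/cuMemFreeAsync before the exporter frees it */
}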
+ * + * \note The cuMemFreeAsync api may be used in the exporting process before + * the cuMemFreeAsync operation completes in its stream as long as the + * cuMemFreeAsync in the exporting process specifies a stream with + * a stream dependency on the importing process's cuMemFreeAsync. + * + * \param[out] ptr_out - pointer to imported memory + * \param[in] pool - pool from which to import + * \param[in] shareData - data specifying the memory to import + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolImportFromShareableHandle, ::cuMemPoolExportPointer + */ +CUresult CUDAAPI cuMemPoolImportPointer(CUdeviceptr *ptr_out, CUmemoryPool pool, CUmemPoolPtrExportData *shareData); + +/** @} */ /* END CUDA_MALLOC_ASYNC */ + /** * \defgroup CUDA_UNIFIED Unified Addressing * @@ -8208,7 +10611,6 @@ CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); * */ -#if __CUDA_API_VERSION >= 4000 /** * \brief Returns information about a pointer * @@ -8314,11 +10716,47 @@ CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); * Returns in \p *data a boolean that indicates whether the pointer points to * managed memory or not. * + * If \p ptr is not a valid CUDA pointer then ::CUDA_ERROR_INVALID_VALUE is returned. + * * - ::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL: * * Returns in \p *data an integer representing a device ordinal of a device against * which the memory was allocated or registered. * + * - ::CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE: + * + * Returns in \p *data a boolean that indicates if this pointer maps to + * an allocation that is suitable for ::cudaIpcGetMemHandle. + * + * - ::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR: + * + * Returns in \p *data the starting address for the allocation referenced + * by the device pointer \p ptr. Note that this is not necessarily the + * address of the mapped region, but the address of the mappable address + * range \p ptr references (e.g. from ::cuMemAddressReserve). + * + * - ::CU_POINTER_ATTRIBUTE_RANGE_SIZE: + * + * Returns in \p *data the size for the allocation referenced by the device + * pointer \p ptr. Note that this is not necessarily the size of the mapped + * region, but the size of the mappable address range \p ptr references + * (e.g. from ::cuMemAddressReserve). To retrieve the size of the mapped + * region, see ::cuMemGetAddressRange + * + * - ::CU_POINTER_ATTRIBUTE_MAPPED: + * + * Returns in \p *data a boolean that indicates if this pointer is in a + * valid address range that is mapped to a backing allocation. + * + * - ::CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES: + * + * Returns a bitmask of the allowed handle types for an allocation that may + * be passed to ::cuMemExportToShareableHandle. + * + * - ::CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE: + * + * Returns in \p *data the handle to the mempool that the allocation was obtained from. 
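To illustrate the new range attributes, a small helper that reports the base address and size of the mapping backing an arbitrary device pointer (error handling omitted):

#include <cuda.h>
#include <stdio.h>

void describe_pointer(CUdeviceptr ptr)
{
    CUdeviceptr base   = 0;
    size_t      size   = 0;
    int         mapped = 0;

    cuPointerGetAttribute(&base,   CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, ptr);
    cuPointerGetAttribute(&size,   CU_POINTER_ATTRIBUTE_RANGE_SIZE,       ptr);
    cuPointerGetAttribute(&mapped, CU_POINTER_ATTRIBUTE_MAPPED,           ptr);

    printf("range base %p, size %zu, mapped=%d\n", (void *)base, size, mapped);
}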
+ * * \par * * Note that for most allocations in the unified virtual address space @@ -8362,9 +10800,7 @@ CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); * ::cudaPointerGetAttributes */ CUresult CUDAAPI cuPointerGetAttribute(void *data, CUpointer_attribute attribute, CUdeviceptr ptr); -#endif /* __CUDA_API_VERSION >= 4000 */ -#if __CUDA_API_VERSION >= 8000 /** * \brief Prefetches memory to the specified destination device * @@ -8639,14 +11075,12 @@ CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, CUmem_range * ::CUDA_ERROR_INVALID_DEVICE * \notefnerr * - * \sa ::cuMemRangeGetAttribute, ::cuMemAdvise + * \sa ::cuMemRangeGetAttribute, ::cuMemAdvise, * ::cuMemPrefetchAsync, * ::cudaMemRangeGetAttributes */ CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, CUmem_range_attribute *attributes, size_t numAttributes, CUdeviceptr devPtr, size_t count); -#endif /* __CUDA_API_VERSION >= 8000 */ -#if __CUDA_API_VERSION >= 6000 /** * \brief Set attributes on a previously allocated memory region * @@ -8688,9 +11122,7 @@ CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, CUmem_r * ::cuMemHostUnregister */ CUresult CUDAAPI cuPointerSetAttribute(const void *value, CUpointer_attribute attribute, CUdeviceptr ptr); -#endif /* __CUDA_API_VERSION >= 6000 */ -#if __CUDA_API_VERSION >= 7000 /** * \brief Returns information about a pointer. * @@ -8704,6 +11136,12 @@ CUresult CUDAAPI cuPointerSetAttribute(const void *value, CUpointer_attribute at * - ::CU_POINTER_ATTRIBUTE_BUFFER_ID * - ::CU_POINTER_ATTRIBUTE_IS_MANAGED * - ::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL + * - ::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR + * - ::CU_POINTER_ATTRIBUTE_RANGE_SIZE + * - ::CU_POINTER_ATTRIBUTE_MAPPED + * - ::CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE + * - ::CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES + * - ::CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE * * \param numAttributes - Number of attributes to query * \param attributes - An array of attributes to query @@ -8733,7 +11171,6 @@ CUresult CUDAAPI cuPointerSetAttribute(const void *value, CUpointer_attribute at * ::cudaPointerGetAttributes */ CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, CUpointer_attribute *attributes, void **data, CUdeviceptr ptr); -#endif /* __CUDA_API_VERSION >= 7000 */ /** @} */ /* END CUDA_UNIFIED */ @@ -8753,7 +11190,9 @@ CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, CUpointer_at * \brief Create a stream * * Creates a stream and returns a handle in \p phStream. The \p Flags argument - * determines behaviors of the stream. Valid values for \p Flags are: + * determines behaviors of the stream. + * + * Valid values for \p Flags are: * - ::CU_STREAM_DEFAULT: Default stream creation flag. 
* - ::CU_STREAM_NON_BLOCKING: Specifies that work running in the created * stream may run concurrently with work in stream 0 (the NULL stream), and that @@ -8892,8 +11331,6 @@ CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); */ CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); -#if __CUDA_API_VERSION >= 9020 - /** * \brief Query the context associated with a stream * @@ -8938,8 +11375,6 @@ CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); */ CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); -#endif /* __CUDA_API_VERSION >= 9020 */ - /** * \brief Make a compute stream wait on an event * @@ -8948,9 +11383,15 @@ CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); * The synchronization will be performed efficiently on the device when applicable. * \p hEvent may be from a different context or device than \p hStream. * + * flags include: + * - ::CU_EVENT_WAIT_DEFAULT: Default event creation flag. + * - ::CU_EVENT_WAIT_EXTERNAL: Event is captured in the graph as an external + * event node when performing stream capture. This flag is invalid outside + * of stream capture. + * * \param hStream - Stream to wait * \param hEvent - Event to wait on (may not be NULL) - * \param Flags - Parameters for the operation (must be 0) + * \param Flags - See ::CUevent_capture_flags * * \return * ::CUDA_SUCCESS, @@ -9046,8 +11487,6 @@ CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, unsigned in */ CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, CUstreamCallback callback, void *userData, unsigned int flags); -#if __CUDA_API_VERSION >= 10000 - /** * \brief Begins graph capture on a stream * @@ -9056,9 +11495,16 @@ CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, CUstreamCallback callback * a graph, which will be returned via ::cuStreamEndCapture. Capture may not be initiated * if \p stream is CU_STREAM_LEGACY. Capture must be ended on the same stream in which * it was initiated, and it may only be initiated if the stream is not already in capture - * mode. The capture mode may be queried via ::cuStreamIsCapturing. + * mode. The capture mode may be queried via ::cuStreamIsCapturing. A unique id + * representing the capture sequence may be queried via ::cuStreamGetCaptureInfo. + * + * If \p mode is not ::CU_STREAM_CAPTURE_MODE_RELAXED, ::cuStreamEndCapture must be + * called on this stream from the same thread. * * \param hStream - Stream in which to initiate capture + * \param mode - Controls the interaction of this capture sequence with other API + * calls that are potentially unsafe. For more details see + * ::cuThreadExchangeStreamCaptureMode. * * \note Kernels captured using this API must not use texture and surface references. * Reading or writing through any texture or surface reference is undefined @@ -9074,9 +11520,63 @@ CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, CUstreamCallback callback * \sa * ::cuStreamCreate, * ::cuStreamIsCapturing, - * ::cuStreamEndCapture + * ::cuStreamEndCapture, + * ::cuThreadExchangeStreamCaptureMode */ -CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream); +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream, CUstreamCaptureMode mode); + +/** + * \brief Swaps the stream capture interaction mode for a thread + * + * Sets the calling thread's stream capture interaction mode to the value contained + * in \p *mode, and overwrites \p *mode with the previous mode for the thread. 
To + * facilitate deterministic behavior across function or module boundaries, callers + * are encouraged to use this API in a push-pop fashion: \code + CUstreamCaptureMode mode = desiredMode; + cuThreadExchangeStreamCaptureMode(&mode); + ... + cuThreadExchangeStreamCaptureMode(&mode); // restore previous mode + * \endcode + * + * During stream capture (see ::cuStreamBeginCapture), some actions, such as a call + * to ::cudaMalloc, may be unsafe. In the case of ::cudaMalloc, the operation is + * not enqueued asynchronously to a stream, and is not observed by stream capture. + * Therefore, if the sequence of operations captured via ::cuStreamBeginCapture + * depended on the allocation being replayed whenever the graph is launched, the + * captured graph would be invalid. + * + * Therefore, stream capture places restrictions on API calls that can be made within + * or concurrently to a ::cuStreamBeginCapture-::cuStreamEndCapture sequence. This + * behavior can be controlled via this API and flags to ::cuStreamBeginCapture. + * + * A thread's mode is one of the following: + * - \p CU_STREAM_CAPTURE_MODE_GLOBAL: This is the default mode. If the local thread has + * an ongoing capture sequence that was not initiated with + * \p CU_STREAM_CAPTURE_MODE_RELAXED at \p cuStreamBeginCapture, or if any other thread + * has a concurrent capture sequence initiated with \p CU_STREAM_CAPTURE_MODE_GLOBAL, + * this thread is prohibited from potentially unsafe API calls. + * - \p CU_STREAM_CAPTURE_MODE_THREAD_LOCAL: If the local thread has an ongoing capture + * sequence not initiated with \p CU_STREAM_CAPTURE_MODE_RELAXED, it is prohibited + * from potentially unsafe API calls. Concurrent capture sequences in other threads + * are ignored. + * - \p CU_STREAM_CAPTURE_MODE_RELAXED: The local thread is not prohibited from potentially + * unsafe API calls. Note that the thread is still prohibited from API calls which + * necessarily conflict with stream capture, for example, attempting ::cuEventQuery + * on an event that was last recorded inside a capture sequence. + * + * \param mode - Pointer to mode value to swap with the current mode + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuStreamBeginCapture + */ +CUresult CUDAAPI cuThreadExchangeStreamCaptureMode(CUstreamCaptureMode *mode); /** * \brief Ends capture on a stream, returning the captured graph @@ -9086,6 +11586,10 @@ CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream); * If capture was invalidated, due to a violation of the rules of stream capture, then * a NULL graph will be returned. * + * If the \p mode argument to ::cuStreamBeginCapture was not + * ::CU_STREAM_CAPTURE_MODE_RELAXED, this call must be from the same thread as + * ::cuStreamBeginCapture. 
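A sketch of the push-pop pattern recommended above wrapped around a capture sequence; the work recorded inside the capture is elided and error handling is omitted:

#include <cuda.h>

CUgraph capture_work(CUstream stream)
{
    /* Set this thread's capture interaction mode for the duration, then restore it. */
    CUstreamCaptureMode mode = CU_STREAM_CAPTURE_MODE_THREAD_LOCAL;
    cuThreadExchangeStreamCaptureMode(&mode);

    cuStreamBeginCapture(stream, CU_STREAM_CAPTURE_MODE_THREAD_LOCAL);
    /* ... enqueue kernels, copies, and event waits on `stream` ... */
    CUgraph graph = NULL;
    cuStreamEndCapture(stream, &graph);

    cuThreadExchangeStreamCaptureMode(&mode);  /* restore the previous mode */
    return graph;  /* instantiate and launch, or cuGraphDestroy when done */
}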
+ * * \param hStream - Stream to query * \param phGraph - The captured graph * @@ -9093,7 +11597,8 @@ CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream); * ::CUDA_SUCCESS, * ::CUDA_ERROR_DEINITIALIZED, * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_VALUE + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD * \notefnerr * * \sa @@ -9143,9 +11648,120 @@ CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); */ CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *captureStatus); -#endif /* __CUDA_API_VERSION >= 10000 */ +/** + * \brief Query capture status of a stream + * + * Note there is a later version of this API, ::cuStreamGetCaptureInfo_v2. It will + * supplant this version in 12.0, which is retained for minor version compatibility. + * + * Query the capture status of a stream and and get an id for + * the capture sequence, which is unique over the lifetime of the process. + * + * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not created + * with ::CU_STREAM_NON_BLOCKING is capturing, returns ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. + * + * A valid id is returned only if both of the following are true: + * - the call returns CUDA_SUCCESS + * - captureStatus is set to ::CU_STREAM_CAPTURE_STATUS_ACTIVE + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \notefnerr + * + * \sa + * ::cuStreamGetCaptureInfo_v2, + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing + */ +CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); -#if __CUDA_API_VERSION >= 6000 +/** + * \brief Query a stream's capture state (11.3+) + * + * Query stream state related to stream capture. + * + * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not created + * with ::CU_STREAM_NON_BLOCKING is capturing, returns ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. + * + * Valid data (other than capture status) is returned only if both of the following are true: + * - the call returns CUDA_SUCCESS + * - the returned capture status is ::CU_STREAM_CAPTURE_STATUS_ACTIVE + * + * This version of cuStreamGetCaptureInfo is introduced in CUDA 11.3 and will supplant the + * previous version in 12.0. Developers requiring compatibility across minor versions to + * CUDA 11.0 (driver version 445) should use ::cuStreamGetCaptureInfo or include a fallback + * path. + * + * \param hStream - The stream to query + * \param captureStatus_out - Location to return the capture status of the stream; required + * \param id_out - Optional location to return an id for the capture sequence, which is + * unique over the lifetime of the process + * \param graph_out - Optional location to return the graph being captured into. All + * operations other than destroy and node removal are permitted on the graph + * while the capture sequence is in progress. This API does not transfer + * ownership of the graph, which is transferred or destroyed at + * ::cuStreamEndCapture. Note that the graph handle may be invalidated before + * end of capture for certain errors. Nodes that are or become + * unreachable from the original stream at ::cuStreamEndCapture due to direct + * actions on the graph do not trigger ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED. + * \param dependencies_out - Optional location to store a pointer to an array of nodes. 
+ * The next node to be captured in the stream will depend on this set of nodes, + * absent operations such as event wait which modify this set. The array pointer + * is valid until the next API call which operates on the stream or until end of + * capture. The node handles may be copied out and are valid until they or the + * graph is destroyed. The driver-owned array may also be passed directly to + * APIs that operate on the graph (not the stream) without copying. + * \param numDependencies_out - Optional location to store the size of the array + * returned in dependencies_out. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamGetCaptureInfo, + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing, + * ::cuStreamUpdateCaptureDependencies + */ +CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out); + +/** + * \brief Update the set of dependencies in a capturing stream (11.3+) + * + * Modifies the dependency set of a capturing stream. The dependency set is the set + * of nodes that the next captured node in the stream will depend on. + * + * Valid flags are ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES and + * ::CU_STREAM_SET_CAPTURE_DEPENDENCIES. These control whether the set passed to + * the API is added to the existing set or replaces it. A flags value of 0 defaults + * to ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES. + * + * Nodes that are removed from the dependency set via this API do not result in + * ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED if they are unreachable from the stream at + * ::cuStreamEndCapture. + * + * Returns ::CUDA_ERROR_ILLEGAL_STATE if the stream is not capturing. + * + * This API is new in CUDA 11.3. Developers requiring compatibility across minor + * versions to CUDA 11.0 should not use this API or provide a fallback. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_ILLEGAL_STATE + * + * \sa + * ::cuStreamBeginCapture, + * ::cuStreamGetCaptureInfo, + * ::cuStreamGetCaptureInfo_v2 + */ +CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags); /** * \brief Attach memory to a stream asynchronously @@ -9235,8 +11851,6 @@ CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *ca */ CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, size_t length, unsigned int flags); -#endif /* __CUDA_API_VERSION >= 6000 */ - /** * \brief Determine status of a compute stream * @@ -9296,7 +11910,6 @@ CUresult CUDAAPI cuStreamQuery(CUstream hStream); */ CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); -#if __CUDA_API_VERSION >= 4000 /** * \brief Destroys a stream * @@ -9326,7 +11939,71 @@ CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); * ::cudaStreamDestroy */ CUresult CUDAAPI cuStreamDestroy(CUstream hStream); -#endif /* __CUDA_API_VERSION >= 4000 */ + +/** + * \brief Copies attributes from source stream to destination stream. + * + * Copies attributes from source stream \p src to destination stream \p dst. + * Both streams must have the same context. 
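A sketch of the 11.3+ query on a possibly-capturing stream (error handling omitted); the reported dependency set could then be edited with ::cuStreamUpdateCaptureDependencies:

#include <cuda.h>
#include <stdio.h>

void inspect_capture(CUstream stream)
{
    CUstreamCaptureStatus status;
    cuuint64_t            id = 0;
    CUgraph               graph = NULL;
    const CUgraphNode    *deps = NULL;
    size_t                numDeps = 0;

    cuStreamGetCaptureInfo_v2(stream, &status, &id, &graph, &deps, &numDeps);
    if (status == CU_STREAM_CAPTURE_STATUS_ACTIVE)
        printf("capture %llu active: next node has %zu dependencies\n",
               (unsigned long long)id, numDeps);
    else
        printf("stream is not capturing\n");
}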
+ * + * \param[out] dst Destination stream + * \param[in] src Source stream + * For list of attributes see ::CUstreamAttrID + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamCopyAttributes(CUstream dst, CUstream src); + +/** + * \brief Queries stream attribute. + * + * Queries attribute \p attr from \p hStream and stores it in corresponding + * member of \p value_out. + * + * \param[in] hStream + * \param[in] attr + * \param[out] value_out + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, + CUstreamAttrValue *value_out); + +/** + * \brief Sets stream attribute. + * + * Sets attribute \p attr on \p hStream from corresponding attribute of + * \p value. The updated attribute will be applied to subsequent work + * submitted to the stream. It will not affect previously submitted work. + * + * \param[out] hStream + * \param[in] attr + * \param[in] value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, + const CUstreamAttrValue *value); /** @} */ /* END CUDA_STREAM */ @@ -9421,10 +12098,60 @@ CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags); * ::cuStreamWaitEvent, * ::cuEventDestroy, * ::cuEventElapsedTime, - * ::cudaEventRecord + * ::cudaEventRecord, + * ::cuEventRecordWithFlags */ CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); +/** + * \brief Records an event + * + * Captures in \p hEvent the contents of \p hStream at the time of this call. + * \p hEvent and \p hStream must be from the same context. + * Calls such as ::cuEventQuery() or ::cuStreamWaitEvent() will then + * examine or wait for completion of the work that was captured. Uses of + * \p hStream after this call do not modify \p hEvent. See note on default + * stream behavior for what is captured in the default case. + * + * ::cuEventRecordWithFlags() can be called multiple times on the same event and + * will overwrite the previously captured state. Other APIs such as + * ::cuStreamWaitEvent() use the most recently captured state at the time + * of the API call, and are not affected by later calls to + * ::cuEventRecordWithFlags(). Before the first call to ::cuEventRecordWithFlags(), an + * event represents an empty set of work, so for example ::cuEventQuery() + * would return ::CUDA_SUCCESS. + * + * flags include: + * - ::CU_EVENT_RECORD_DEFAULT: Default event creation flag. + * - ::CU_EVENT_RECORD_EXTERNAL: Event is captured in the graph as an external + * event node when performing stream capture. This flag is invalid outside + * of stream capture. 
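A sketch of setting a stream attribute; the access-policy window assumes a device with an L2 set-aside (compute capability 8.0 or newer), and the sizes and hit ratio are placeholders:

#include <cuda.h>

void set_persistent_window(CUstream stream, CUdeviceptr buf, size_t bytes)
{
    CUstreamAttrValue attr = {0};
    attr.accessPolicyWindow.base_ptr  = (void *)buf;
    attr.accessPolicyWindow.num_bytes = bytes;
    attr.accessPolicyWindow.hitRatio  = 0.6f;  /* fraction treated as persisting */
    attr.accessPolicyWindow.hitProp   = CU_ACCESS_PROPERTY_PERSISTING;
    attr.accessPolicyWindow.missProp  = CU_ACCESS_PROPERTY_STREAMING;

    cuStreamSetAttribute(stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, &attr);
    /* cuStreamCopyAttributes(otherStream, stream) would propagate this window. */
}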
+ * + * \param hEvent - Event to record + * \param hStream - Stream to record event for + * \param flags - See ::CUevent_capture_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \note_null_stream + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuStreamWaitEvent, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cuEventRecord, + * ::cudaEventRecord + */ +CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, unsigned int flags); + /** * \brief Queries an event's status * @@ -9488,7 +12215,6 @@ CUresult CUDAAPI cuEventQuery(CUevent hEvent); */ CUresult CUDAAPI cuEventSynchronize(CUevent hEvent); -#if __CUDA_API_VERSION >= 4000 /** * \brief Destroys an event * @@ -9517,7 +12243,6 @@ CUresult CUDAAPI cuEventSynchronize(CUevent hEvent); * ::cudaEventDestroy */ CUresult CUDAAPI cuEventDestroy(CUevent hEvent); -#endif /* __CUDA_API_VERSION >= 4000 */ /** * \brief Computes the elapsed time between two events @@ -9577,8 +12302,6 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven * @{ */ -#if __CUDA_API_VERSION >= 10000 - /** * \brief Imports an external memory object * @@ -9598,7 +12321,9 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven void *handle; const void *name; } win32; + const void *nvSciBufObject; } handle; + unsigned long long size; unsigned int flags; } CUDA_EXTERNAL_MEMORY_HANDLE_DESC; * \endcode @@ -9609,11 +12334,14 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven * * \code typedef enum CUexternalMemoryHandleType_enum { - CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, - CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, - CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, - CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, - CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5 + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE = 6, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT = 7, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF = 8 } CUexternalMemoryHandleType; * \endcode * @@ -9654,7 +12382,7 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be * NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle * is not NULL, then it must represent a valid shared NT handle that - * is returned by ID3DDevice::CreateSharedHandle when referring to a + * is returned by ID3D12Device::CreateSharedHandle when referring to a * ID3D12Heap object. This handle holds a reference to the underlying * object. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name * is not NULL, then it must point to a NULL-terminated array of @@ -9666,17 +12394,55 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be * NULL. 
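For reference, the usual record/synchronize/elapsed-time pattern with the plain record call; ::cuEventRecordWithFlags is only needed when an event must become an external node during stream capture (error handling omitted):

#include <cuda.h>
#include <stdio.h>

void time_work(CUstream stream)
{
    CUevent start, stop;
    cuEventCreate(&start, CU_EVENT_DEFAULT);
    cuEventCreate(&stop,  CU_EVENT_DEFAULT);

    cuEventRecord(start, stream);
    /* ... enqueue the work to be timed on `stream` ... */
    cuEventRecord(stop, stream);

    cuEventSynchronize(stop);
    float ms = 0.0f;
    cuEventElapsedTime(&ms, start, stop);
    printf("elapsed: %.3f ms\n", ms);

    cuEventDestroy(start);
    cuEventDestroy(stop);
}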
If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle * is not NULL, then it must represent a valid shared NT handle that - * is returned by ID3DDevice::CreateSharedHandle when referring to a + * is returned by ID3D12Device::CreateSharedHandle when referring to a * ID3D12Resource object. This handle holds a reference to the * underlying object. If * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name * is not NULL, then it must point to a NULL-terminated array of * UTF-16 characters that refers to a ID3D12Resource object. * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must + * represent a valid shared NT handle that is returned by + * IDXGIResource1::CreateSharedHandle when referring to a + * ID3D11Resource object. If + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * is not NULL, then it must point to a NULL-terminated array of + * UTF-16 characters that refers to a ID3D11Resource object. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must + * represent a valid shared KMT handle that is returned by + * IDXGIResource::GetSharedHandle when referring to a + * ID3D11Resource object and + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * must be NULL. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::nvSciBufObject must be non-NULL + * and reference a valid NvSciBuf object. + * If the NvSciBuf object imported into CUDA is also mapped by other drivers, then the + * application must use ::cuWaitExternalSemaphoresAsync or ::cuSignalExternalSemaphoresAsync + * as appropriate barriers to maintain coherence between CUDA and the other drivers. + * See ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC and ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC + * for memory synchronization. + * + * + * The size of the memory object must be specified in + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::size. + * * Specifying the flag ::CUDA_EXTERNAL_MEMORY_DEDICATED in * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::flags indicates that the * resource is a dedicated resource. The definition of what a * dedicated resource is outside the scope of this extension. + * This flag must be set if ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type + * is one of the following: + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT * * \param extMem_out - Returned handle to an external memory object * \param memHandleDesc - Memory import handle descriptor @@ -9684,6 +12450,7 @@ CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUeven * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_HANDLE * \notefnerr * @@ -9734,6 +12501,8 @@ CUresult CUDAAPI cuImportExternalMemory(CUexternalMemory *extMem_out, const CUDA * appropriate offsets to the returned pointer to derive the * individual buffers. * + * The returned pointer \p devPtr must be freed using ::cuMemFree. 
+ * * \param devPtr - Returned device pointer to buffer * \param extMem - Handle to external memory object * \param bufferDesc - Buffer descriptor @@ -9741,10 +12510,11 @@ CUresult CUDAAPI cuImportExternalMemory(CUexternalMemory *extMem_out, const CUDA * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_HANDLE * \notefnerr * - * \sa ::cuImportExternalMemory + * \sa ::cuImportExternalMemory, * ::cuDestroyExternalMemory, * ::cuExternalMemoryGetMappedMipmappedArray */ @@ -9781,6 +12551,11 @@ CUresult CUDAAPI cuExternalMemoryGetMappedBuffer(CUdeviceptr *devPtr, CUexternal * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::numLevels specifies * the total number of levels in the mipmap chain. * + * If \p extMem was imported from a handle of type ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, then + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::numLevels must be equal to 1. + * + * The returned CUDA mipmapped array must be freed using ::cuMipmappedArrayDestroy. + * * \param mipmap - Returned CUDA mipmapped array * \param extMem - Handle to external memory object * \param mipmapDesc - CUDA array descriptor @@ -9788,22 +12563,23 @@ CUresult CUDAAPI cuExternalMemoryGetMappedBuffer(CUdeviceptr *devPtr, CUexternal * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, * ::CUDA_ERROR_INVALID_HANDLE * \notefnerr * - * \sa ::cuImportExternalMemory + * \sa ::cuImportExternalMemory, * ::cuDestroyExternalMemory, * ::cuExternalMemoryGetMappedBuffer */ CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray(CUmipmappedArray *mipmap, CUexternalMemory extMem, const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *mipmapDesc); /** - * \brief Releases all resources associated with an external memory - * object. + * \brief Destroys an external memory object. * - * Frees all buffers and CUDA mipmapped arrays that were - * mapped onto this external memory object and releases any reference - * on the underlying memory itself. + * Destroys the specified external memory object. Any existing buffers + * and CUDA mipmapped arrays mapped onto this object must no longer be + * used and must be explicitly freed using ::cuMemFree and + * ::cuMipmappedArrayDestroy respectively. 
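A sketch of importing an opaque file descriptor (for example one exported by a Vulkan device-memory allocation) and mapping it as a linear buffer; the fd, size, and offset are assumptions and error handling is omitted:

#include <cuda.h>

CUdeviceptr import_buffer(int fd, unsigned long long size,
                          CUexternalMemory *extMem_out)
{
    CUDA_EXTERNAL_MEMORY_HANDLE_DESC memDesc = {0};
    memDesc.type      = CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD;
    memDesc.handle.fd = fd;     /* ownership of the fd transfers on success */
    memDesc.size      = size;
    cuImportExternalMemory(extMem_out, &memDesc);

    CUDA_EXTERNAL_MEMORY_BUFFER_DESC bufDesc = {0};
    bufDesc.offset = 0;
    bufDesc.size   = size;

    CUdeviceptr dptr = 0;
    cuExternalMemoryGetMappedBuffer(&dptr, *extMem_out, &bufDesc);
    return dptr;  /* cuMemFree(dptr), then cuDestroyExternalMemory(*extMem_out) */
}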
* * \param extMem - External memory object to be destroyed * @@ -9813,7 +12589,7 @@ CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray(CUmipmappedArray *mipma * ::CUDA_ERROR_INVALID_HANDLE * \notefnerr * - * \sa ::cuImportExternalMemory + * \sa ::cuImportExternalMemory, * ::cuExternalMemoryGetMappedBuffer, * ::cuExternalMemoryGetMappedMipmappedArray */ @@ -9836,8 +12612,9 @@ CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); int fd; struct { void *handle; - const void *name; + const void *name; } win32; + const void* NvSciSyncObj; } handle; unsigned int flags; } CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; @@ -9849,10 +12626,16 @@ CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); * * \code typedef enum CUexternalSemaphoreHandleType_enum { - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, - CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4 + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE = 5, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC = 6, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX = 7, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT = 8, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD = 9, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 = 10 } CUexternalSemaphoreHandleType; * \endcode * @@ -9867,7 +12650,7 @@ CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, then exactly one * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and - * ::cudaExternalSemaphoreHandleDesc::handle::win32::name must not be + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be * NULL. If * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle * is not NULL, then it must represent a valid shared NT handle that @@ -9890,23 +12673,76 @@ CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, then exactly one * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and - * ::cudaExternalSemaphoreHandleDesc::handle::win32::name must not be + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be * NULL. If * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle * is not NULL, then it must represent a valid shared NT handle that - * is returned by ID3DDevice::CreateSharedHandle when referring to a + * is returned by ID3D12Device::CreateSharedHandle when referring to a * ID3D12Fence object. This handle holds a reference to the underlying * object. If * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name * is not NULL, then it must name a valid synchronization object that * refers to a valid ID3D12Fence object. * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared NT handle that is returned by + * ID3D11Fence::CreateSharedHandle. 
If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid ID3D11Fence object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::nvSciSyncObj + * represents a valid NvSciSyncObj. + * + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared NT handle that + * is returned by IDXGIResource1::CreateSharedHandle when referring to + * a IDXGIKeyedMutex object. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid IDXGIKeyedMutex object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared KMT handle that + * is returned by IDXGIResource::GetSharedHandle when referring to + * a IDXGIKeyedMutex object and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must be NULL. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::fd must be a valid + * file descriptor referencing a synchronization object. Ownership of + * the file descriptor is transferred to the CUDA driver when the + * handle is imported successfully. Performing any operations on the + * file descriptor after it is imported results in undefined behavior. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32, then exactly one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * references a synchronization object. Ownership of this handle is + * not transferred to CUDA after the import operation, so the + * application must release the handle using the appropriate system + * call. If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object. + * * \param extSem_out - Returned handle to an external semaphore * \param semHandleDesc - Semaphore import handle descriptor * * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, * ::CUDA_ERROR_INVALID_HANDLE * \notefnerr * @@ -9932,20 +12768,49 @@ CUresult CUDAAPI cuImportExternalSemaphore(CUexternalSemaphore *extSem_out, cons * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT * then signaling the semaphore will set it to the signaled state. 
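A sketch of importing a timeline-semaphore fd and signaling it behind the work queued in a stream, using the signal call documented here; the fd origin (e.g. a Vulkan timeline semaphore export) is an assumption and error handling is omitted:

#include <cuda.h>
#include <string.h>

void signal_peer(int fd, CUstream stream, unsigned long long value)
{
    CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC desc;
    memset(&desc, 0, sizeof(desc));
    desc.type      = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD;
    desc.handle.fd = fd;            /* ownership of the fd transfers on success */

    CUexternalSemaphore extSem;
    cuImportExternalSemaphore(&extSem, &desc);

    CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS sig;
    memset(&sig, 0, sizeof(sig));
    sig.params.fence.value = value; /* timeline value the peer will wait for */
    cuSignalExternalSemaphoresAsync(&extSem, &sig, 1, stream);

    cuStreamSynchronize(stream);    /* outstanding signals must finish before destroy */
    cuDestroyExternalSemaphore(extSem);
}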
* - * If the semaphore object is of the type - * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, then the - * semaphore will be set to the value specified in + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * then the semaphore will be set to the value specified in * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::fence::value. * + * If the semaphore object is of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC + * this API sets ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence + * to a value that can be used by subsequent waiters of the same NvSciSync object + * to order operations with those currently submitted in \p stream. Such an update + * will overwrite previous contents of + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence. By default, + * signaling such an external semaphore object causes appropriate memory synchronization + * operations to be performed over all external memory objects that are imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. This ensures that any subsequent accesses + * made by other importers of the same set of NvSciBuf memory object(s) are coherent. + * These operations can be skipped by specifying the flag + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC, which can be used as a + * performance optimization when data coherency is not required. But specifying this + * flag in scenarios where data coherency is required results in undefined behavior. + * Also, for semaphore object of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, + * if the NvSciSyncAttrList used to create the NvSciSyncObj had not set the flags in + * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_SIGNAL, this API will return + * CUDA_ERROR_NOT_SUPPORTED. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then the keyed mutex will be released with the key specified in + * ::CUDA_EXTERNAL_SEMAPHORE_PARAMS::params::keyedmutex::key. + * * \param extSemArray - Set of external semaphores to be signaled * \param paramsArray - Array of semaphore parameters * \param numExtSems - Number of semaphores to signal - * \param stream - Stream to enqueue the signal operations in + * \param stream - Stream to enqueue the signal operations in * * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_HANDLE + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED * \notefnerr * * \sa ::cuImportExternalSemaphore, @@ -9973,12 +12838,44 @@ CUresult CUDAAPI cuSignalExternalSemaphoresAsync(const CUexternalSemaphore *extS * unsignaled state. Therefore for every signal operation, there can * only be one wait operation. 
* - * If the semaphore object is of the type - * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, then waiting on - * the semaphore will wait until the value of the semaphore is - * greater than or equal to + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * then waiting on the semaphore will wait until the value of the + * semaphore is greater than or equal to * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::fence::value. * + * If the semaphore object is of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC + * then, waiting on the semaphore will wait until the + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence is signaled by the + * signaler of the NvSciSyncObj that was associated with this semaphore object. + * By default, waiting on such an external semaphore object causes appropriate + * memory synchronization operations to be performed over all external memory objects + * that are imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. This ensures that + * any subsequent accesses made by other importers of the same set of NvSciBuf memory + * object(s) are coherent. These operations can be skipped by specifying the flag + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC, which can be used as a + * performance optimization when data coherency is not required. But specifying this + * flag in scenarios where data coherency is required results in undefined behavior. + * Also, for semaphore object of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, + * if the NvSciSyncAttrList used to create the NvSciSyncObj had not set the flags in + * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_WAIT, this API will return + * CUDA_ERROR_NOT_SUPPORTED. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then the keyed mutex will be acquired when it is released with the key + * specified in ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::keyedmutex::key + * or until the timeout specified by + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::keyedmutex::timeoutMs + * has lapsed. The timeout interval can either be a finite value + * specified in milliseconds or an infinite value. In case an infinite + * value is specified the timeout never elapses. The windows INFINITE + * macro must be used to specify infinite timeout. 
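The matching wait-side sketch for the same fence-style semaphore; the keyed-mutex variant appears only as a comment because the member spelling (keyedMutex) is taken from recent headers rather than from the text above:

    /* Block later work in `stream` until the fence reaches at least `value`. */
    static CUresult wait_fence(CUexternalSemaphore extSem, CUstream stream,
                               unsigned long long value)
    {
        CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS p;
        memset(&p, 0, sizeof(p));
        p.params.fence.value = value;
        /* For a keyed-mutex handle one would instead set (member names assumed):
         *   p.params.keyedMutex.key       = 0;
         *   p.params.keyedMutex.timeoutMs = INFINITE;   // Windows macro: never time out
         */
        return cuWaitExternalSemaphoresAsync(&extSem, &p, 1, stream);
    }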
+ * * \param extSemArray - External semaphores to be waited on * \param paramsArray - Array of semaphore parameters * \param numExtSems - Number of semaphores to wait on @@ -9987,7 +12884,9 @@ CUresult CUDAAPI cuSignalExternalSemaphoresAsync(const CUexternalSemaphore *extS * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_NOT_INITIALIZED, - * ::CUDA_ERROR_INVALID_HANDLE + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_TIMEOUT * \notefnerr * * \sa ::cuImportExternalSemaphore, @@ -10017,8 +12916,6 @@ CUresult CUDAAPI cuWaitExternalSemaphoresAsync(const CUexternalSemaphore *extSem */ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); -#endif /* __CUDA_API_VERSION >= 10000 */ - /** @} */ /* END CUDA_EXTRES_INTEROP */ /** @@ -10064,7 +12961,6 @@ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); * @{ */ -#if __CUDA_API_VERSION >= 8000 /** * \brief Wait on a memory location * @@ -10097,7 +12993,7 @@ CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); * * \sa ::cuStreamWaitValue64, * ::cuStreamWriteValue32, - * ::cuStreamWriteValue64 + * ::cuStreamWriteValue64, * ::cuStreamBatchMemOp, * ::cuMemHostRegister, * ::cuStreamWaitEvent @@ -10242,7 +13138,6 @@ CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, cuuint6 * ::cuMemHostRegister */ CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstreamBatchMemOpParams *paramArray, unsigned int flags); -#endif /* __CUDA_API_VERSION >= 8000 */ /** @} */ /* END CUDA_MEMOP */ @@ -10293,7 +13188,7 @@ CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstrea * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: The maximum size in bytes of * dynamically-allocated shared memory. * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: Preferred shared memory-L1 - * cache split ratio in percent of shared memory. + * cache split ratio in percent of total shared memory. * * \param pi - Returned attribute value * \param attrib - Attribute requested @@ -10312,13 +13207,11 @@ CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstrea * ::cuCtxSetCacheConfig, * ::cuFuncSetCacheConfig, * ::cuLaunchKernel, - * ::cudaFuncGetAttributes + * ::cudaFuncGetAttributes, * ::cudaFuncSetAttribute */ CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunction hfunc); -#if __CUDA_API_VERSION >= 9000 - /** * \brief Sets information about a function * @@ -10339,8 +13232,9 @@ CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunc * architecture. * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: On devices where the L1 * cache and shared memory use the same hardware resources, this sets the shared memory - * carveout preference, in percent of the total resources. This is only a hint, and the - * driver can choose a different ratio if required to execute the function. + * carveout preference, in percent of the total shared memory. + * See ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR + * This is only a hint, and the driver can choose a different ratio if required to execute the function. 
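A short sketch of setting the two attributes above on a kernel handle before it is launched; the 64 KiB and 100 percent values are illustrative only:

    /* Opt the kernel into a large dynamic shared-memory allocation and ask for
     * the maximum shared-memory carveout (a hint only, per the note above). */
    static void configure_shared_memory(CUfunction func)
    {
        cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, 64 * 1024);
        cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT, 100);

        int maxDynSmem = 0;   /* read back the effective limit */
        cuFuncGetAttribute(&maxDynSmem, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, func);
    }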
* * \param hfunc - Function to query attribute of * \param attrib - Attribute requested @@ -10359,11 +13253,10 @@ CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunc * ::cuCtxSetCacheConfig, * ::cuFuncSetCacheConfig, * ::cuLaunchKernel, - * ::cudaFuncGetAttributes + * ::cudaFuncGetAttributes, * ::cudaFuncSetAttribute */ CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); -#endif // __CUDA_API_VERSION >= 9000 /** * \brief Sets the preferred cache configuration for a device function @@ -10409,7 +13302,6 @@ CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attri */ CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); -#if __CUDA_API_VERSION >= 4020 /** * \brief Sets the shared memory configuration for a device function. * @@ -10461,9 +13353,33 @@ CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); * ::cudaFuncSetSharedMemConfig */ CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, CUsharedconfig config); -#endif -#if __CUDA_API_VERSION >= 4000 +/** + * \brief Returns a module handle + * + * Returns in \p *hmod the handle of the module that function \p hfunc + * is located in. The lifetime of the module corresponds to the lifetime of + * the context it was loaded in or until the module is explicitly unloaded. + * + * The CUDA runtime manages its own modules loaded into the primary context. + * If the handle returned by this API refers to a module loaded by the CUDA runtime, + * calling ::cuModuleUnload() on that module will result in undefined behavior. + * + * \param hmod - Returned module handle + * \param hfunc - Function to retrieve module for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + */ +CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); + /** * \brief Launches a CUDA function * @@ -10522,8 +13438,8 @@ CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, CUsharedconfig confi * parameters are specified with both \p kernelParams and \p extra * (i.e. both \p kernelParams and \p extra are non-NULL). * - * Calling ::cuLaunchKernel() sets persistent function state that is - * the same as function state set through the following deprecated APIs: + * Calling ::cuLaunchKernel() invalidates the persistent function state + * set through the following deprecated APIs: * ::cuFuncSetBlockShape(), * ::cuFuncSetSharedSize(), * ::cuParamSetSize(), @@ -10531,10 +13447,6 @@ CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, CUsharedconfig confi * ::cuParamSetf(), * ::cuParamSetv(). * - * When the kernel \p f is launched via ::cuLaunchKernel(), the previous - * block shape, shared size and parameter info associated with \p f - * is overwritten. - * * Note that to use ::cuLaunchKernel(), the kernel \p f must either have * been compiled with toolchain version 3.2 or later so that it will * contain kernel parameter information, or have no kernel parameters. 
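A sketch of the kernelParams path for cuLaunchKernel, assuming a kernel that takes two device pointers and an element count; cuFuncGetModule, introduced above, is shown alongside it:

    /* Launch `func` on a 1-D grid; argument values are read from the addresses
     * in `args` at launch time, so locals are fine here. */
    static CUresult launch_1d(CUfunction func, CUstream stream,
                              CUdeviceptr dX, CUdeviceptr dY, int n)
    {
        CUmodule mod;
        cuFuncGetModule(&mod, func);          /* module that owns `func` */

        void *args[] = { &dX, &dY, &n };
        unsigned int grid = (unsigned int)((n + 255) / 256);
        return cuLaunchKernel(func,
                              grid, 1, 1,     /* grid dimensions  */
                              256, 1, 1,      /* block dimensions */
                              0,              /* dynamic shared memory */
                              stream,
                              args,
                              NULL);          /* no `extra` packaging */
    }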
@@ -10586,8 +13498,7 @@ CUresult CUDAAPI cuLaunchKernel(CUfunction f, CUstream hStream, void **kernelParams, void **extra); -#endif /* __CUDA_API_VERSION >= 4000 */ -#if __CUDA_API_VERSION >= 9000 + /** * \brief Launches a CUDA function where thread blocks can cooperate and synchronize as they execute * @@ -10678,6 +13589,8 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, /** * \brief Launches CUDA functions on multiple devices where thread blocks can cooperate and synchronize as they execute * + * \deprecated This function is deprecated as of CUDA 11.3. + * * Invokes kernels as specified in the \p launchParamsList array where each element * of the array specifies all the parameters required to perform a single kernel launch. * These kernels can cooperate and synchronize as they execute. The size of the array is @@ -10808,11 +13721,7 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, * ::cuLaunchCooperativeKernel, * ::cudaLaunchCooperativeKernelMultiDevice */ -CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice(CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, unsigned int flags); - -#endif /* __CUDA_API_VERSION >= 9000 */ - -#if __CUDA_API_VERSION >= 10000 +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice(CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, unsigned int flags); /** * \brief Enqueues a host function call in a stream @@ -10879,8 +13788,6 @@ CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice(CUDA_LAUNCH_PARAMS *launch */ CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, void *userData); -#endif /* __CUDA_API_VERSION >= 10000 */ - /** @} */ /* END CUDA_EXEC */ /** @@ -11107,6 +14014,21 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, voi * contains the number of threads specified by a previous call to * ::cuFuncSetBlockShape(). * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. + * * \param f - Kernel to launch * * \return @@ -11144,6 +14066,21 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuLaunch(CUfunction f); * blocks. Each block contains the number of threads specified by a previous * call to ::cuFuncSetBlockShape(). * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. 
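A sketch of cuLaunchHostFunc with a trivial completion flag; host functions enqueued this way run once all prior work in the stream has finished and must not call back into the CUDA API:

    /* Host callback: set a flag; no CUDA calls are allowed from here. */
    static void CUDA_CB notify_done(void *userData)
    {
        *(int *)userData = 1;
    }

    static void enqueue_notification(CUstream stream, int *flag)
    {
        cuLaunchHostFunc(stream, notify_done, flag);
    }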
+ * * \param f - Kernel to launch * \param grid_width - Width of grid in blocks * \param grid_height - Height of grid in blocks @@ -11183,6 +14120,21 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, in * blocks. Each block contains the number of threads specified by a previous * call to ::cuFuncSetBlockShape(). * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. + * * \param f - Kernel to launch * \param grid_width - Width of grid in blocks * \param grid_height - Height of grid in blocks @@ -11202,8 +14154,8 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, in * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED * * \note In certain cases where cubins are created with no ABI (i.e., using \p ptxas \p --abi-compile \p no), - * this function may serialize kernel launches. In order to force the CUDA driver to retain - * asynchronous behavior, set the ::CU_CTX_LMEM_RESIZE_TO_MAX flag during context creation (see ::cuCtxCreate). + * this function may serialize kernel launches. The CUDA driver retains asynchronous behavior by + * growing the per-thread stack as needed per launch and not shrinking it afterwards. * * \note_null_stream * \notefnerr @@ -11247,7 +14199,6 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, int grid_widt __CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, int texunit, CUtexref hTexRef); /** @} */ /* END CUDA_EXEC_DEPRECATED */ -#if __CUDA_API_VERSION >= 10000 /** * \defgroup CUDA_GRAPH Graph Management * @@ -11334,9 +14285,9 @@ CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags); * parameter will be copied. The number of kernel parameters and their offsets and sizes do not need * to be specified as that information is retrieved directly from the kernel's image. * - * 2) Kernel parameters can also be packaged by the application into a single buffer that is passed in - * via \p extra. This places the burden on the application of knowing each kernel - * parameter's size and alignment/padding within the buffer. The \p extra parameter exists + * 2) Kernel parameters for non-cooperative kernels can also be packaged by the application into a single + * buffer that is passed in via \p extra. This places the burden on the application of knowing each + * kernel parameter's size and alignment/padding within the buffer. The \p extra parameter exists * to allow this function to take additional less commonly used arguments. \p extra specifies * a list of names of extra settings and their corresponding values. Each extra setting name is * immediately followed by the corresponding value. The list must be terminated with either NULL or @@ -11354,8 +14305,8 @@ CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags); * ::CU_LAUNCH_PARAM_BUFFER_POINTER; * * The error ::CUDA_ERROR_INVALID_VALUE will be returned if kernel parameters are specified with both - * \p kernelParams and \p extra (i.e. both \p kernelParams and - * \p extra are non-NULL). 
+ * \p kernelParams and \p extra (i.e. both \p kernelParams and \p extra are non-NULL). + * ::CUDA_ERROR_INVALID_VALUE will be returned if \p extra is used for a cooperative kernel. * * The \p kernelParams or \p extra array, as well as the argument values it points to, * are copied during this call. @@ -11380,6 +14331,7 @@ CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags); * * \sa * ::cuLaunchKernel, + * ::cuLaunchCooperativeKernel, * ::cuGraphKernelNodeGetParams, * ::cuGraphKernelNodeSetParams, * ::cuGraphCreate, @@ -11390,7 +14342,7 @@ CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags); * ::cuGraphAddMemcpyNode, * ::cuGraphAddMemsetNode */ -CUresult CUDAAPI cuGraphAddKernelNode(CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams); +CUresult CUDAAPI cuGraphAddKernelNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams); /** * \brief Returns a kernel node's parameters @@ -11493,7 +14445,7 @@ CUresult CUDAAPI cuGraphKernelNodeSetParams(CUgraphNode hNode, const CUDA_KERNEL * ::cuGraphAddHostNode, * ::cuGraphAddMemsetNode */ -CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, size_t numDependencies, const CUDA_MEMCPY3D *copyParams, CUcontext ctx); +CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_MEMCPY3D *copyParams, CUcontext ctx); /** * \brief Returns a memcpy node's parameters @@ -11581,7 +14533,7 @@ CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, const CUDA_MEMCPY * ::cuGraphAddHostNode, * ::cuGraphAddMemcpyNode */ -CUresult CUDAAPI cuGraphAddMemsetNode(CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx); +CUresult CUDAAPI cuGraphAddMemsetNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx); /** * \brief Returns a memset node's parameters @@ -11639,6 +14591,7 @@ CUresult CUDAAPI cuGraphMemsetNodeSetParams(CUgraphNode hNode, const CUDA_MEMSET * A handle to the new node will be returned in \p phGraphNode. * * When the graph is launched, the node will invoke the specified CPU function. + * Host nodes are not supported under MPS with pre-Volta GPUs. 
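A sketch of building a one-node graph with cuGraphAddKernelNode; func, dX and n are assumed to exist, and taking argument addresses from locals is safe because, as noted above, the parameter array and values are copied during the call:

    static CUgraphNode add_kernel_node(CUgraph graph, CUfunction func,
                                       CUdeviceptr dX, int n)
    {
        CUDA_KERNEL_NODE_PARAMS kp;
        CUgraphNode node;
        void *args[] = { &dX, &n };

        memset(&kp, 0, sizeof(kp));
        kp.func           = func;
        kp.gridDimX       = (unsigned int)((n + 127) / 128);
        kp.gridDimY       = kp.gridDimZ = 1;
        kp.blockDimX      = 128;
        kp.blockDimY      = kp.blockDimZ = 1;
        kp.sharedMemBytes = 0;
        kp.kernelParams   = args;   /* copied by cuGraphAddKernelNode */
        kp.extra          = NULL;

        cuGraphAddKernelNode(&node, graph, NULL, 0, &kp);   /* root: no dependencies */
        return node;
    }

The graph itself would have been created beforehand with cuGraphCreate(&graph, 0).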
* * \param phGraphNode - Returns newly created node * \param hGraph - Graph to which to add the node @@ -11650,6 +14603,7 @@ CUresult CUDAAPI cuGraphMemsetNodeSetParams(CUgraphNode hNode, const CUDA_MEMSET * ::CUDA_SUCCESS, * ::CUDA_ERROR_DEINITIALIZED, * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, * ::CUDA_ERROR_INVALID_VALUE * \note_graph_thread_safety * \notefnerr @@ -11666,7 +14620,7 @@ CUresult CUDAAPI cuGraphMemsetNodeSetParams(CUgraphNode hNode, const CUDA_MEMSET * ::cuGraphAddMemcpyNode, * ::cuGraphAddMemsetNode */ -CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, size_t numDependencies, const CUDA_HOST_NODE_PARAMS *nodeParams); +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_HOST_NODE_PARAMS *nodeParams); /** * \brief Returns a host node's parameters @@ -11723,6 +14677,8 @@ CUresult CUDAAPI cuGraphHostNodeSetParams(CUgraphNode hNode, const CUDA_HOST_NOD * at the root of the graph. \p dependencies may not have any duplicate entries. * A handle to the new node will be returned in \p phGraphNode. * + * If \p hGraph contains allocation or free nodes, this call will return an error. + * * The node executes an embedded child graph. The child graph is cloned in this call. * * \param phGraphNode - Returns newly created node @@ -11750,7 +14706,7 @@ CUresult CUDAAPI cuGraphHostNodeSetParams(CUgraphNode hNode, const CUDA_HOST_NOD * ::cuGraphAddMemsetNode, * ::cuGraphClone */ -CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, size_t numDependencies, CUgraph childGraph); +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUgraph childGraph); /** * \brief Gets a handle to the embedded graph of a child graph node @@ -11759,6 +14715,9 @@ CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, CUgraph hGra * does not clone the graph. Changes to the graph will be reflected in * the node, and the node retains ownership of the graph. * + * Allocation and free nodes cannot be added to the returned graph. + * Attempting to do so will return an error. + * * \param hNode - Node to get the embedded graph for * \param phGraph - Location to store a handle to the graph * @@ -11812,12 +14771,654 @@ CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode hNode, CUgraph *phGra * ::cuGraphAddMemcpyNode, * ::cuGraphAddMemsetNode */ -CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, CUgraphNode *dependencies, size_t numDependencies); +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies); + +/** + * \brief Creates an event record node and adds it to a graph + * + * Creates a new event record node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and event specified in \p event. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * Each launch of the graph will record \p event to capture execution of the + * node's dependencies. 
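Continuing the same sketch, a host node that reuses the notify_done callback from the cuLaunchHostFunc example and runs after the kernel node on every graph launch (not supported under MPS with pre-Volta GPUs, per the note above):

    static CUgraphNode add_host_node(CUgraph graph, CUgraphNode dep, int *flag)
    {
        CUDA_HOST_NODE_PARAMS hp;
        CUgraphNode node;

        memset(&hp, 0, sizeof(hp));
        hp.fn       = notify_done;   /* same CUhostFn signature as cuLaunchHostFunc */
        hp.userData = flag;

        cuGraphAddHostNode(&node, graph, &dep, 1, &hp);
        return node;
    }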
+ * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param event - Event for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddEventRecordNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUevent event); + +/** + * \brief Returns the event associated with an event record node + * + * Returns the event of event record node \p hNode in \p event_out. + * + * \param hNode - Node to get the event for + * \param event_out - Pointer to return the event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventRecordNodeGetEvent(CUgraphNode hNode, CUevent *event_out); + +/** + * \brief Sets an event record node's event + * + * Sets the event of event record node \p hNode to \p event. + * + * \param hNode - Node to set the event for + * \param event - Event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventRecordNodeSetEvent(CUgraphNode hNode, CUevent event); + +/** + * \brief Creates an event wait node and adds it to a graph + * + * Creates a new event wait node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and event specified in \p event. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * The graph node will wait for all work captured in \p event. See ::cuEventRecord() + * for details on what is captured by an event. \p event may be from a different context + * or device than the launch stream. 
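A sketch that uses the two node types above to order work across graphs: the producer graph records an event after its last node and the consumer graph waits on it at its root (the event is assumed to have been created with cuEventCreate):

    static void link_with_event(CUgraph producer, CUgraphNode producerTail,
                                CUgraph consumer, CUevent ev)
    {
        CUgraphNode recordNode, waitNode;

        /* Record `ev` once `producerTail` and its dependencies have executed. */
        cuGraphAddEventRecordNode(&recordNode, producer, &producerTail, 1, ev);

        /* Make the consumer graph's root wait for everything captured in `ev`. */
        cuGraphAddEventWaitNode(&waitNode, consumer, NULL, 0, ev);
    }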
+ * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param event - Event for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddEventWaitNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUevent event); + +/** + * \brief Returns the event associated with an event wait node + * + * Returns the event of event wait node \p hNode in \p event_out. + * + * \param hNode - Node to get the event for + * \param event_out - Pointer to return the event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventWaitNodeGetEvent(CUgraphNode hNode, CUevent *event_out); + +/** + * \brief Sets an event wait node's event + * + * Sets the event of event wait node \p hNode to \p event. + * + * \param hNode - Node to set the event for + * \param event - Event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventWaitNodeSetEvent(CUgraphNode hNode, CUevent event); + +/** + * \brief Creates an external semaphore signal node and adds it to a graph + * + * Creates a new external semaphore signal node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments specified + * in \p nodeParams. It is possible for \p numDependencies to be 0, in which case the + * node will be placed at the root of the graph. \p dependencies may not have any + * duplicate entries. A handle to the new node will be returned in \p phGraphNode. + * + * Performs a signal operation on a set of externally allocated semaphore objects + * when the node is launched. The operation(s) will occur after all of the node's + * dependencies have completed. 
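Continuing the external-semaphore sketches, a signal node that sets a fence value once its dependency has completed; the extSemArray, paramsArray and numExtSems members match the parameter descriptions used further below:

    static CUresult add_signal_node(CUgraph graph, CUgraphNode dep,
                                    CUexternalSemaphore extSem,
                                    unsigned long long fenceValue,
                                    CUgraphNode *out)
    {
        CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS sig;
        CUDA_EXT_SEM_SIGNAL_NODE_PARAMS np;

        memset(&sig, 0, sizeof(sig));
        memset(&np, 0, sizeof(np));
        sig.params.fence.value = fenceValue;
        np.extSemArray = &extSem;
        np.paramsArray = &sig;
        np.numExtSems  = 1;

        return cuGraphAddExternalSemaphoresSignalNode(out, graph, &dep, 1, &np);
    }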
+ * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExternalSemaphoresSignalNodeGetParams, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddExternalSemaphoresSignalNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Returns an external semaphore signal node's parameters + * + * Returns the parameters of an external semaphore signal node \p hNode in \p params_out. + * The \p extSemArray and \p paramsArray returned in \p params_out, + * are owned by the node. This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphExternalSemaphoresSignalNodeSetParams to update the + * parameters of this node. + * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeGetParams(CUgraphNode hNode, CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *params_out); + +/** + * \brief Sets an external semaphore signal node's parameters + * + * Sets the parameters of an external semaphore signal node \p hNode to \p nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeSetParams(CUgraphNode hNode, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Creates an external semaphore wait node and adds it to a graph + * + * Creates a new external semaphore wait node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. 
+ * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. A handle + * to the new node will be returned in \p phGraphNode. + * + * Performs a wait operation on a set of externally allocated semaphore objects + * when the node is launched. The node's dependencies will not be launched until + * the wait operation has completed. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExternalSemaphoresWaitNodeGetParams, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddExternalSemaphoresWaitNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Returns an external semaphore wait node's parameters + * + * Returns the parameters of an external semaphore wait node \p hNode in \p params_out. + * The \p extSemArray and \p paramsArray returned in \p params_out, + * are owned by the node. This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphExternalSemaphoresSignalNodeSetParams to update the + * parameters of this node. + * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeGetParams(CUgraphNode hNode, CUDA_EXT_SEM_WAIT_NODE_PARAMS *params_out); + +/** + * \brief Sets an external semaphore wait node's parameters + * + * Sets the parameters of an external semaphore wait node \p hNode to \p nodeParams. 
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Creates an allocation node and adds it to a graph + * + * Creates a new allocation node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. A handle + * to the new node will be returned in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in + * \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches. + * + * If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode, + * the allocation can be accessed by nodes ordered after the allocation node but before the free node. + * These allocations cannot be freed outside the owning graph, and they can only be freed once in the + * owning graph. + * + * If the allocation is not freed in the same graph, then it can be accessed not only by nodes in the + * graph which are ordered after the allocation node, but also by stream operations ordered after the + * graph's execution but before the allocation is freed. + * + * Allocations which are not freed in the same graph can be freed by: + * - passing the allocation to ::cuMemFreeAsync or ::cuMemFree; + * - launching a graph with a free node for that allocation; or + * - specifying ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH during instantiation, which makes + * each launch behave as though it called ::cuMemFreeAsync for every unfreed allocation. + * + * It is not possible to free an allocation in both the owning graph and another graph. If the allocation + * is freed in the same graph, a free node cannot be added to another graph. If the allocation is freed + * in another graph, a free node can no longer be added to the owning graph. + * + * The following restrictions apply to graphs which contain allocation and/or memory free nodes: + * - Nodes and edges of the graph cannot be deleted. + * - The graph cannot be used in a child node. + * - Only one instantiation of the graph may exist at any point in time. + * - The graph cannot be cloned. 
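A sketch of an allocation node; the text above references poolProps and dptr, while the remaining CUDA_MEM_ALLOC_NODE_PARAMS member names (bytesize in particular) and the pool-property constants follow recent headers and should be treated as assumptions:

    static CUdeviceptr add_alloc_node(CUgraph graph, size_t bytes, int device,
                                      CUgraphNode *out)
    {
        CUDA_MEM_ALLOC_NODE_PARAMS ap;
        memset(&ap, 0, sizeof(ap));
        ap.poolProps.allocType     = CU_MEM_ALLOCATION_TYPE_PINNED;
        ap.poolProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        ap.poolProps.location.id   = device;     /* device ordinal */
        ap.bytesize                = bytes;      /* assumed member name */

        cuGraphAddMemAllocNode(out, graph, NULL, 0, &ap);
        return ap.dptr;   /* fixed across instantiations and launches */
    }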
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemFreeNode, + * ::cuGraphMemAllocNodeGetParams, + * ::cuDeviceGraphMemTrim, + * ::cuDeviceGetGraphMemAttribute, + * ::cuDeviceSetGraphMemAttribute, + * ::cuMemAllocAsync, + * ::cuMemFreeAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemAllocNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUDA_MEM_ALLOC_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a memory alloc node's parameters + * + * Returns the parameters of a memory alloc node \p hNode in \p params_out. + * The \p poolProps and \p accessDescs returned in \p params_out, are owned by the + * node. This memory remains valid until the node is destroyed. The returned + * parameters must not be modified. + * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphMemFreeNodeGetParams + */ +CUresult CUDAAPI cuGraphMemAllocNodeGetParams(CUgraphNode hNode, CUDA_MEM_ALLOC_NODE_PARAMS *params_out); + +/** + * \brief Creates a memory free node and adds it to a graph + * + * Creates a new memory free node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. A handle + * to the new node will be returned in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param dptr - Address of memory to free + * + * ::cuGraphAddMemFreeNode will return ::CUDA_ERROR_INVALID_VALUE if the user attempts to free: + * - an allocation twice in the same graph. + * - an address that was not returned by an allocation node. + * - an invalid address. + * + * The following restrictions apply to graphs which contain allocation and/or memory free nodes: + * - Nodes and edges of the graph cannot be deleted. + * - The graph cannot be used in a child node. + * - Only one instantiation of the graph may exist at any point in time. + * - The graph cannot be cloned. 
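And the matching free node, ordered after the node(s) that consume the allocation returned by the sketch above:

    static CUgraphNode add_free_node(CUgraph graph, CUgraphNode consumer, CUdeviceptr dptr)
    {
        CUgraphNode freeNode;
        /* `dptr` must be an address returned by an allocation node in this graph. */
        cuGraphAddMemFreeNode(&freeNode, graph, &consumer, 1, dptr);
        return freeNode;
    }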
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphMemFreeNodeGetParams, + * ::cuDeviceGraphMemTrim, + * ::cuDeviceGetGraphMemAttribute, + * ::cuDeviceSetGraphMemAttribute, + * ::cuMemAllocAsync, + * ::cuMemFreeAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemFreeNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUdeviceptr dptr); + +/** + * \brief Returns a memory free node's parameters + * + * Returns the address of a memory free node \p hNode in \p dptr_out. + * + * \param hNode - Node to get the parameters for + * \param dptr_out - Pointer to return the device address + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemFreeNode, + * ::cuGraphMemAllocNodeGetParams + */ +CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, CUdeviceptr *dptr_out); + +/** + * \brief Free unused memory that was cached on the specified device for use with graphs back to the OS. + * + * Blocks which are not in use by a graph that is either currently executing or scheduled to execute are + * freed back to the operating system. + * + * \param device - The device for which cached memory should be freed. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode, + * ::cuDeviceSetGraphMemAttribute, + * ::cuDeviceGetGraphMemAttribute + */ +CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); + +/** + * \brief Query asynchronous allocation attributes related to graphs + * + * Valid attributes are: + * + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT: Amount of memory, in bytes, currently associated with graphs + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_HIGH: High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT: Amount of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH: High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + * + * \param device - Specifies the scope of the query + * \param attr - attribute to get + * \param value - retrieved value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDeviceSetGraphMemAttribute, + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode + */ +CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, CUgraphMem_attribute attr, void* value); + +/** + * \brief Set asynchronous allocation attributes related to graphs + * + * Valid attributes are: + * + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_HIGH: High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. 
+ * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH: High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + * + * \param device - Specifies the scope of the query + * \param attr - attribute to get + * \param value - pointer to value to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDeviceGetGraphMemAttribute, + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode + */ +CUresult CUDAAPI cuDeviceSetGraphMemAttribute(CUdevice device, CUgraphMem_attribute attr, void* value); /** * \brief Clones a graph * - * This function creates a copy of \p originalGraph and returns it in \p * phGraphClone. + * This function creates a copy of \p originalGraph and returns it in \p phGraphClone. * All parameters are copied into the cloned graph. The original graph may be modified * after this call without affecting the clone. * @@ -12082,7 +15683,7 @@ CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, CUgraphNode *de * ::cuGraphNodeGetDependencies, * ::cuGraphNodeGetDependentNodes */ -CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, CUgraphNode *from, CUgraphNode *to, size_t numDependencies); +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, const CUgraphNode *from, const CUgraphNode *to, size_t numDependencies); /** * \brief Removes dependency edges from a graph @@ -12094,6 +15695,9 @@ CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, CUgraphNode *from, CUgra * If \p numDependencies is 0, elements in \p from and \p to will be ignored. * Specifying a non-existing dependency will return an error. * + * Dependencies cannot be removed from graphs which contain allocation or free nodes. + * Any attempt to do so will return an error. + * * \param hGraph - Graph from which to remove dependencies * \param from - Array of nodes that provide the dependencies * \param to - Array of dependent nodes @@ -12111,7 +15715,7 @@ CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, CUgraphNode *from, CUgra * ::cuGraphNodeGetDependencies, * ::cuGraphNodeGetDependentNodes */ -CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, CUgraphNode *from, CUgraphNode *to, size_t numDependencies); +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, const CUgraphNode *from, const CUgraphNode *to, size_t numDependencies); /** * \brief Remove a node from the graph @@ -12119,6 +15723,9 @@ CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, CUgraphNode *from, CU * Removes \p hNode from its graph. This operation also severs any dependencies of other nodes * on \p hNode and vice versa. * + * Nodes which belong to a graph which contains allocation or free nodes cannot be destroyed. + * Any attempt to do so will return an error. + * * \param hNode - Node to remove * * \return @@ -12143,7 +15750,7 @@ CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode); * Instantiates \p hGraph as an executable graph. The graph is validated for any * structural constraints or intra-node constraints which were not previously * validated. If instantiation is successful, a handle to the instantiated graph - * is returned in \p graphExec. + * is returned in \p phGraphExec. * * If there are any errors, diagnostic information may be returned in \p errorNode and * \p logBuffer. This is the primary way to inspect instantiation errors. 
The output @@ -12167,12 +15774,461 @@ CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode); * \notefnerr * * \sa + * ::cuGraphInstantiateWithFlags, * ::cuGraphCreate, + * ::cuGraphUpload, * ::cuGraphLaunch, * ::cuGraphExecDestroy */ CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); +/** + * \brief Creates an executable graph from a graph + * + * Instantiates \p hGraph as an executable graph. The graph is validated for any + * structural constraints or intra-node constraints which were not previously + * validated. If instantiation is successful, a handle to the instantiated graph + * is returned in \p phGraphExec. + * + * The \p flags parameter controls the behavior of instantiation and subsequent + * graph launches. Valid flags are: + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, which configures a + * graph containing memory allocation nodes to automatically free any + * unfreed memory allocations before the graph is relaunched. + * + * If \p hGraph contains any allocation or free nodes, there can be at most one + * executable graph in existence for that graph at a time. + * + * An attempt to instantiate a second executable graph before destroying the first + * with ::cuGraphExecDestroy will result in an error. + * + * \param phGraphExec - Returns instantiated graph + * \param hGraph - Graph to instantiate + * \param flags - Flags to control instantiation. See ::CUgraphInstantiate_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphCreate, + * ::cuGraphUpload, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphInstantiateWithFlags(CUgraphExec *phGraphExec, CUgraph hGraph, unsigned long long flags); + +/** + * \brief Sets the parameters for a kernel node in the given graphExec + * + * Sets the parameters of a kernel node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. The \p func field + * of \p nodeParams cannot be modified and must match the original value. + * All other values can be modified. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. 
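A sketch tying the graph sketches together: instantiate, launch, then retune the kernel node's grid for subsequent launches only (the func member must stay the same, per the text above); cuGraphUpload, cuGraphLaunch and cuGraphExecDestroy are taken from the \sa lists:

    static void run_and_retune(CUgraph graph, CUgraphNode kernelNode,
                               CUDA_KERNEL_NODE_PARAMS *kp, CUstream stream)
    {
        CUgraphExec exec;

        cuGraphInstantiateWithFlags(&exec, graph, 0);   /* no special flags */
        cuGraphUpload(exec, stream);                    /* optional warm-up */
        cuGraphLaunch(exec, stream);

        /* Affects future launches only; already-enqueued launches are untouched. */
        kp->gridDimX = 512;
        cuGraphExecKernelNodeSetParams(exec, kernelNode, kp);
        cuGraphLaunch(exec, stream);

        cuGraphExecDestroy(exec);
    }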
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - kernel node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for a memcpy node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode had + * contained \p copyParams at instantiation. hNode must remain in the graph which was + * used to instantiate \p hGraphExec. Changed edges to and from hNode are ignored. + * + * The source and destination memory in \p copyParams must be allocated from the same + * contexts as the original source and destination memory. Both the instantiation-time + * memory operands and the memory operands in \p copyParams must be 1-dimensional. + * Zero-length operations are not supported. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. hNode is also + * not modified by this call. + * + * Returns CUDA_ERROR_INVALID_VALUE if the memory operands' mappings changed or + * either the original or new memory operands are multidimensional. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Memcpy node from the graph which was used to instantiate graphExec + * \param copyParams - The updated parameters to set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecMemcpyNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_MEMCPY3D *copyParams, CUcontext ctx); + +/** + * \brief Sets the parameters for a memset node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode had + * contained \p memsetParams at instantiation. hNode must remain in the graph which was + * used to instantiate \p hGraphExec. Changed edges to and from hNode are ignored. + * + * The destination memory in \p memsetParams must be allocated from the same + * contexts as the original destination memory. Both the instantiation-time + * memory operand and the memory operand in \p memsetParams must be 1-dimensional. 
+ * Zero-length operations are not supported. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. hNode is also + * not modified by this call. + * + * Returns CUDA_ERROR_INVALID_VALUE if the memory operand's mappings changed or + * either the original or new memory operand are multidimensional. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Memset node from the graph which was used to instantiate graphExec + * \param memsetParams - The updated parameters to set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecMemsetNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx); + +/** + * \brief Sets the parameters for a host node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode had + * contained \p nodeParams at instantiation. hNode must remain in the graph which was + * used to instantiate \p hGraphExec. Changed edges to and from hNode are ignored. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. hNode is also + * not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Host node from the graph which was used to instantiate graphExec + * \param nodeParams - The updated parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddHostNode, + * ::cuGraphHostNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecHostNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Updates node parameters in the child graph node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though the nodes contained + * in \p hNode's graph had the parameters contained in \p childGraph's nodes at instantiation. + * \p hNode must remain in the graph which was used to instantiate \p hGraphExec. + * Changed edges to and from \p hNode are ignored. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. \p hNode is also + * not modified by this call. 
+ * + * The topology of \p childGraph, as well as the node insertion order, must match that + * of the graph contained in \p hNode. See ::cuGraphExecUpdate() for a list of restrictions + * on what can be updated in an instantiated graph. The update is recursive, so child graph + * nodes contained within the top level child graph will also be updated. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Child graph node from the graph which was used to instantiate graphExec + * \param childGraph - The graph supplying the updated parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphChildGraphNodeGetGraph, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecChildGraphNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, CUgraph childGraph); + +/** + * \brief Sets the event for an event record node in the given graphExec + * + * Sets the event of an event record node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - event record node from the graph from which graphExec was instantiated + * \param event - Updated event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecEventRecordNodeSetEvent(CUgraphExec hGraphExec, CUgraphNode hNode, CUevent event); + +/** + * \brief Sets the event for an event wait node in the given graphExec + * + * Sets the event of an event wait node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call.
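As an illustrative sketch of the two event setters (the wait-node variant is symmetric), assuming exec was instantiated from a graph containing the record node recordNode:

    // Create a fresh event and make future launches of exec record it instead
    // of the event that was captured at instantiation time.
    CUevent newEvent;
    cuEventCreate(&newEvent, CU_EVENT_DEFAULT);
    cuGraphExecEventRecordNodeSetEvent(exec, recordNode, newEvent);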
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - event wait node from the graph from which graphExec was instantiated + * \param event - Updated event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecEventWaitNodeSetEvent(CUgraphExec hGraphExec, CUgraphNode hNode, CUevent event); + +/** + * \brief Sets the parameters for an external semaphore signal node in the given graphExec + * + * Sets the parameters of an external semaphore signal node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * Changing \p nodeParams->numExtSems is not supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - semaphore signal node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecExternalSemaphoresSignalNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for an external semaphore wait node in the given graphExec + * + * Sets the parameters of an external semaphore wait node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * Changing \p nodeParams->numExtSems is not supported. 
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - semaphore wait node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecExternalSemaphoresWaitNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Uploads an executable graph in a stream + * + * Uploads \p hGraphExec to the device in \p hStream without executing it. Uploads of + * the same \p hGraphExec will be serialized. Each upload is ordered behind both any + * previous work in \p hStream and any previous launches of \p hGraphExec. + * Uses memory cached by \p stream to back the allocations owned by \p hGraphExec. + * + * \param hGraphExec - Executable graph to upload + * \param hStream - Stream in which to upload the graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraphExec, CUstream hStream); + /** * \brief Launches an executable graph in a stream * @@ -12181,6 +16237,10 @@ CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CU * and any previous launches of \p hGraphExec. To execute a graph concurrently, it must be * instantiated multiple times into multiple executable graphs. * + * If any allocations created by \p hGraphExec remain unfreed (from a previous launch) and + * \p hGraphExec was not instantiated with ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, + * the launch will fail with ::CUDA_ERROR_INVALID_VALUE. 
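A minimal usage sketch combining cuGraphUpload with cuGraphLaunch, assuming graph is an already-populated CUgraph and stream is a valid CUstream:

    // Instantiate once, pre-load the device resources, then launch; the upload
    // is ordered behind prior work in stream and the launch behind the upload.
    CUgraphExec exec;
    cuGraphInstantiate(&exec, graph, NULL, NULL, 0);
    cuGraphUpload(exec, stream);
    cuGraphLaunch(exec, stream);
    cuStreamSynchronize(stream);
    cuGraphExecDestroy(exec);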
+ * * \param hGraphExec - Executable graph to launch * \param hStream - Stream in which to launch the graph * * \return * ::CUDA_SUCCESS, * ::CUDA_ERROR_DEINITIALIZED, * ::CUDA_ERROR_NOT_INITIALIZED, * ::CUDA_ERROR_INVALID_VALUE * \note_graph_thread_safety * \notefnerr * * \sa * ::cuGraphInstantiate, + * ::cuGraphUpload, * ::cuGraphExecDestroy */ CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream); @@ -12218,6 +16279,7 @@ CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream); * * \sa * ::cuGraphInstantiate, + * ::cuGraphUpload, * ::cuGraphLaunch */ CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec); @@ -12241,10 +16303,307 @@ CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec); * ::cuGraphCreate */ CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph); -/** @} */ /* END CUDA_GRAPH */ -#endif /* __CUDA_API_VERSION >= 10000 */ -#if __CUDA_API_VERSION >= 6050 +/** + * \brief Check whether an executable graph can be updated with a graph and perform the update if possible + * + * Updates the node parameters in the instantiated graph specified by \p hGraphExec with the + * node parameters in a topologically identical graph specified by \p hGraph. + * + * Limitations: + * + * - Kernel nodes: + * - The owning context of the function cannot change. + * - A node whose function originally did not use CUDA dynamic parallelism cannot be updated + * to a function which uses CDP. + * - Memset and memcpy nodes: + * - The CUDA device(s) to which the operand(s) was allocated/mapped cannot change. + * - The source/destination memory must be allocated from the same contexts as the original + * source/destination memory. + * - Only 1D memsets can be changed. + * - Additional memcpy node restrictions: + * - Changing either the source or destination memory type (i.e. CU_MEMORYTYPE_DEVICE, + * CU_MEMORYTYPE_ARRAY, etc.) is not supported. + * - External semaphore wait nodes and record nodes: + * - Changing the number of semaphores is not supported. + * + * Note: The API may add further restrictions in future releases. The return code should always be checked. + * + * cuGraphExecUpdate sets \p updateResult_out to CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED under + * the following conditions: + * + * - The count of nodes directly in \p hGraphExec and \p hGraph differ, in which case \p hErrorNode_out + * is NULL. + * - A node is deleted in \p hGraph but not its pair from \p hGraphExec, in which case \p hErrorNode_out + * is NULL. + * - A node is deleted in \p hGraphExec but not its pair from \p hGraph, in which case \p hErrorNode_out is + * the pairless node from \p hGraph. + * - The dependent nodes of a pair differ, in which case \p hErrorNode_out is the node from \p hGraph. + * + * cuGraphExecUpdate sets \p updateResult_out to: + * - CU_GRAPH_EXEC_UPDATE_ERROR if passed an invalid value. + * - CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED if the graph topology changed + * - CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED if the type of a node changed, in which case + * \p hErrorNode_out is set to the node from \p hGraph. + * - CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE if the function changed in an unsupported + * way (see note above), in which case \p hErrorNode_out is set to the node from \p hGraph + * - CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED if any parameters to a node changed in a way + * that is not supported, in which case \p hErrorNode_out is set to the node from \p hGraph.
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED if something about a node is unsupported, like + * the node's type or configuration, in which case \p hErrorNode_out is set to the node from \p hGraph. + * + * If \p updateResult_out isn't set in one of the situations described above, the update check passes + * and cuGraphExecUpdate updates \p hGraphExec to match the contents of \p hGraph. If an error happens + * during the update, \p updateResult_out will be set to CU_GRAPH_EXEC_UPDATE_ERROR; otherwise, + * \p updateResult_out is set to CU_GRAPH_EXEC_UPDATE_SUCCESS. + * + * cuGraphExecUpdate returns CUDA_SUCCESS when the update was performed successfully. It returns + * CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE if the graph update was not performed because it included + * changes which violated constraints specific to instantiated graph update. + * + * \param hGraphExec The instantiated graph to be updated + * \param hGraph The graph containing the updated parameters + * \param hErrorNode_out The node which caused the permissibility check to forbid the update, if any + * \param updateResult_out Whether the graph update was permitted. If it was forbidden, the reason why + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + */ +CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, CUgraphNode *hErrorNode_out, CUgraphExecUpdateResult *updateResult_out); + +/** + * \brief Copies attributes from source node to destination node. + * + * Copies attributes from source node \p src to destination node \p dst. + * Both nodes must have the same context. + * + * \param[out] dst Destination node + * \param[in] src Source node + * For list of attributes see ::CUkernelNodeAttrID + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuGraphKernelNodeCopyAttributes(CUgraphNode dst, CUgraphNode src); + +/** + * \brief Queries node attribute. + * + * Queries attribute \p attr from node \p hNode and stores it in corresponding + * member of \p value_out. + * + * \param[in] hNode + * \param[in] attr + * \param[out] value_out + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuGraphKernelNodeGetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + CUkernelNodeAttrValue *value_out); + +/** + * \brief Sets node attribute. + * + * Sets attribute \p attr on node \p hNode from corresponding attribute of + * \p value. + * + * \param[out] hNode + * \param[in] attr + * \param[out] value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuGraphKernelNodeSetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + const CUkernelNodeAttrValue *value); + +/** + * \brief Write a DOT file describing graph structure + * + * Using the provided \p hGraph, write to \p path a DOT formatted description of the graph. + * By default this includes the graph topology, node types, node ids, kernel names and memcpy direction. + * \p flags can be specified to write more detailed information about each node type such as + * parameter values, kernel attributes, node and function handles.
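To make the update/fallback flow concrete, a sketch of refreshing an instantiated graph after its source graph was modified in a topology-preserving way; exec, graph and stream are assumed to already exist:

    // Try to patch the executable graph in place and fall back to
    // re-instantiation if the update is rejected (e.g. topology changed).
    CUgraphNode errNode;
    CUgraphExecUpdateResult updateResult;
    if (cuGraphExecUpdate(exec, graph, &errNode, &updateResult) != CUDA_SUCCESS) {
        cuGraphExecDestroy(exec);
        cuGraphInstantiate(&exec, graph, NULL, NULL, 0);
    }
    cuGraphLaunch(exec, stream);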
+ * + * \param hGraph - The graph to create a DOT file from + * \param path - The path to write the DOT file to + * \param flags - Flags from CUgraphDebugDot_flags for specifying which additional node information to write + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OPERATING_SYSTEM + */ +CUresult CUDAAPI cuGraphDebugDotPrint(CUgraph hGraph, const char *path, unsigned int flags); + +/** + * \brief Create a user object + * + * Create a user object with the specified destructor callback and initial reference count. The + * initial references are owned by the caller. + * + * Destructor callbacks cannot make CUDA API calls and should avoid blocking behavior, as they + * are executed by a shared internal thread. Another thread may be signaled to perform such + * actions, if it does not block forward progress of tasks scheduled through CUDA. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param object_out - Location to return the user object handle + * \param ptr - The pointer to pass to the destroy function + * \param destroy - Callback to free the user object when it is no longer in use + * \param initialRefcount - The initial refcount to create the object with, typically 1. The + * initial references are owned by the calling thread. + * \param flags - Currently it is required to pass ::CU_USER_OBJECT_NO_DESTRUCTOR_SYNC, + * which is the only defined flag. This indicates that the destroy + * callback cannot be waited on by any CUDA API. Users requiring + * synchronization of the callback should signal its completion + * manually. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectCreate(CUuserObject *object_out, void *ptr, CUhostFn destroy, + unsigned int initialRefcount, unsigned int flags); + +/** + * \brief Retain a reference to a user object + * + * Retains new references to a user object. The new references are owned by the caller. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param object - The object to retain + * \param count - The number of references to retain, typically 1. Must be nonzero + * and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectRetain(CUuserObject object, unsigned int count); + +/** + * \brief Release a reference to a user object + * + * Releases user object references owned by the caller. The object's destructor is invoked if + * the reference count reaches zero. + * + * It is undefined behavior to release references not owned by the caller, or to use a user + * object handle after all references are released. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param object - The object to release + * \param count - The number of references to release, typically 1. Must be nonzero + * and not larger than INT_MAX. 
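For reference, a sketch of tying a host-side resource to a graph's lifetime with a user object; state is an assumed heap object and freeState an assumed CUhostFn that releases it:

    // Wrap the resource in a user object and move the caller's single
    // reference into the graph; freeState runs once every reference,
    // including the graph's, has been released.
    CUuserObject obj;
    cuUserObjectCreate(&obj, state, freeState, 1, CU_USER_OBJECT_NO_DESTRUCTOR_SYNC);
    cuGraphRetainUserObject(graph, obj, 1, CU_GRAPH_USER_OBJECT_MOVE);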
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectRelease(CUuserObject object, unsigned int count); + +/** + * \brief Retain a reference to a user object from a graph + * + * Creates or moves user object references that will be owned by a CUDA graph. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param graph - The graph to associate the reference with + * \param object - The user object to retain a reference for + * \param count - The number of references to add to the graph, typically 1. Must be + * nonzero and not larger than INT_MAX. + * \param flags - The optional flag ::CU_GRAPH_USER_OBJECT_MOVE transfers references + * from the calling thread, rather than create new references. Pass 0 + * to create new references. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphRetainUserObject(CUgraph graph, CUuserObject object, unsigned int count, unsigned int flags); + +/** + * \brief Release a user object reference from a graph + * + * Releases user object references owned by a graph. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param graph - The graph that will release the reference + * \param object - The user object to release a reference for + * \param count - The number of references to release, typically 1. Must be nonzero + * and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphReleaseUserObject(CUgraph graph, CUuserObject object, unsigned int count); + +/** @} */ /* END CUDA_GRAPH */ + /** * \defgroup CUDA_OCCUPANCY Occupancy * @@ -12422,17 +16781,39 @@ CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize(int *minGridSize, int *blockSi */ CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags(int *minGridSize, int *blockSize, CUfunction func, CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, int blockSizeLimit, unsigned int flags); +/** + * \brief Returns dynamic shared memory available per block when launching \p numBlocks blocks on SM + * + * Returns in \p *dynamicSmemSize the maximum size of dynamic shared memory to allow \p numBlocks blocks per SM. 
+ * + * \param dynamicSmemSize - Returned maximum dynamic shared memory + * \param func - Kernel function for which occupancy is calculated + * \param numBlocks - Number of blocks to fit on SM + * \param blockSize - Size of the blocks + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + */ +CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock(size_t *dynamicSmemSize, CUfunction func, int numBlocks, int blockSize); + /** @} */ /* END CUDA_OCCUPANCY */ -#endif /* __CUDA_API_VERSION >= 6050 */ /** - * \defgroup CUDA_TEXREF Texture Reference Management + * \defgroup CUDA_TEXREF_DEPRECATED Texture Reference Management [DEPRECATED] * - * ___MANBRIEF___ texture reference management functions of the low-level CUDA - * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * ___MANBRIEF___ deprecated texture reference management functions of the + * low-level CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ * - * This section describes the texture reference management functions of the - * low-level CUDA driver application programming interface. + * This section describes the deprecated texture reference management + * functions of the low-level CUDA driver application programming interface. * * @{ */ @@ -12440,6 +16821,8 @@ CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags(int *minGridSize, int /** * \brief Binds an array as a texture reference * + * \deprecated + * * Binds the CUDA array \p hArray to the texture reference \p hTexRef. Any * previous address or CUDA array state associated with the texture reference * is superseded by this function. \p Flags must be set to @@ -12464,11 +16847,13 @@ CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags(int *minGridSize, int * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTextureToArray */ -CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, unsigned int Flags); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, unsigned int Flags); /** * \brief Binds a mipmapped array to a texture reference * + * \deprecated + * * Binds the CUDA mipmapped array \p hMipmappedArray to the texture reference \p hTexRef. * Any previous address or CUDA array state associated with the texture reference * is superseded by this function. \p Flags must be set to ::CU_TRSA_OVERRIDE_FORMAT. @@ -12492,12 +16877,13 @@ CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, unsigned int * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, CUmipmappedArray hMipmappedArray, unsigned int Flags); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, CUmipmappedArray hMipmappedArray, unsigned int Flags); -#if __CUDA_API_VERSION >= 3020 /** * \brief Binds an address as a texture reference * + * \deprecated + * * Binds a linear address range to the texture reference \p hTexRef. Any * previous address or CUDA array state associated with the texture reference * is superseded by this function. 
Any memory previously bound to \p hTexRef @@ -12537,11 +16923,13 @@ CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, CUmipmappedArray hM * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTexture */ -CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, CUdeviceptr dptr, size_t bytes); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, CUdeviceptr dptr, size_t bytes); /** * \brief Binds an address as a 2D texture reference * + * \deprecated + * * Binds a linear address range to the texture reference \p hTexRef. Any * previous address or CUDA array state associated with the texture reference * is superseded by this function. Any memory previously bound to \p hTexRef @@ -12590,12 +16978,13 @@ CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, CUdevi * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTexture2D */ -CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); -#endif /* __CUDA_API_VERSION >= 3020 */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); /** * \brief Sets the format for a texture reference * + * \deprecated + * * Specifies the format of the data to be read by the texture reference * \p hTexRef. \p fmt and \p NumPackedComponents are exactly analogous to the * ::Format and ::NumChannels members of the ::CUDA_ARRAY_DESCRIPTOR structure: @@ -12624,11 +17013,13 @@ CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIP * ::cudaBindTextureToArray, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, int NumPackedComponents); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, int NumPackedComponents); /** * \brief Sets the addressing mode for a texture reference * + * \deprecated + * * Specifies the addressing mode \p am for the given dimension \p dim of the * texture reference \p hTexRef. If \p dim is zero, the addressing mode is * applied to the first parameter of the functions used to fetch from the @@ -12668,11 +17059,13 @@ CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, int Num * ::cudaBindTextureToArray, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, CUaddress_mode am); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, CUaddress_mode am); /** * \brief Sets the filtering mode for a texture reference * + * \deprecated + * * Specifies the filtering mode \p fm to be used when reading memory through * the texture reference \p hTexRef. ::CUfilter_mode_enum is defined as: * @@ -12702,11 +17095,13 @@ CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, CUaddress_mod * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTextureToArray */ -CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm); /** * \brief Sets the mipmap filtering mode for a texture reference * + * \deprecated + * * Specifies the mipmap filtering mode \p fm to be used when reading memory through * the texture reference \p hTexRef. 
::CUfilter_mode_enum is defined as: * @@ -12736,11 +17131,13 @@ CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm); * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm); /** * \brief Sets the mipmap level bias for a texture reference * + * \deprecated + * * Specifies the mipmap level bias \p bias to be added to the specified mipmap level when * reading memory through the texture reference \p hTexRef. * @@ -12763,11 +17160,13 @@ CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm) * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias); /** * \brief Sets the mipmap min/max mipmap level clamps for a texture reference * + * \deprecated + * * Specifies the min/max mipmap level clamps, \p minMipmapLevelClamp and \p maxMipmapLevelClamp * respectively, to be used when reading memory through the texture reference * \p hTexRef. @@ -12792,11 +17191,13 @@ CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias); * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, float minMipmapLevelClamp, float maxMipmapLevelClamp); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, float minMipmapLevelClamp, float maxMipmapLevelClamp); /** * \brief Sets the maximum anisotropy for a texture reference * + * \deprecated + * * Specifies the maximum anisotropy \p maxAniso to be used when reading memory through * the texture reference \p hTexRef. * @@ -12820,11 +17221,13 @@ CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, float minMipmapLe * ::cudaBindTextureToArray, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAniso); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAniso); /** * \brief Sets the border color for a texture reference * + * \deprecated + * * Specifies the value of the RGBA color via the \p pBorderColor to the texture reference * \p hTexRef. The color value supports only float type and holds color components in * the following sequence: @@ -12854,11 +17257,13 @@ CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAnis * ::cudaBindTextureToArray, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor); /** * \brief Sets the flags for a texture reference * + * \deprecated + * * Specifies optional flags via \p Flags to specify the behavior of data * returned through the texture reference \p hTexRef. The valid flags are: * @@ -12872,6 +17277,10 @@ CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor); * from [0, Dim) where Dim is the width or height of the CUDA * array. 
Instead, the texture coordinates [0, 1.0) reference * the entire breadth of the array dimension; + * - ::CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, which disables any trilinear + * filtering optimizations. Trilinear optimizations improve texture filtering + * performance by allowing bilinear filtering on textures in scenarios where + * it can closely approximate the expected results. * * \param hTexRef - Texture reference * \param Flags - Optional flags to set @@ -12893,12 +17302,13 @@ CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor); * ::cudaBindTextureToArray, * ::cudaBindTextureToMipmappedArray */ -CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags); -#if __CUDA_API_VERSION >= 3020 /** * \brief Gets the address associated with a texture reference * + * \deprecated + * * Returns in \p *pdptr the base address bound to the texture reference * \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference * is not bound to any device memory range. @@ -12919,12 +17329,13 @@ CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags); * ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef); -#endif /* __CUDA_API_VERSION >= 3020 */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef); /** * \brief Gets the array bound to a texture reference * + * \deprecated + * * Returns in \p *phArray the CUDA array bound to the texture reference * \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference * is not bound to any CUDA array. @@ -12945,11 +17356,13 @@ CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef); * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef); /** * \brief Gets the mipmapped array bound to a texture reference * + * \deprecated + * * Returns in \p *phMipmappedArray the CUDA mipmapped array bound to the texture * reference \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference * is not bound to any CUDA mipmapped array. @@ -12970,11 +17383,13 @@ CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef); * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, CUtexref hTexRef); /** * \brief Gets the addressing mode used by a texture reference * + * \deprecated + * * Returns in \p *pam the addressing mode corresponding to the * dimension \p dim of the texture reference \p hTexRef. Currently, the only * valid value for \p dim are 0 and 1. 
@@ -12996,11 +17411,13 @@ CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, C * ::cuTexRefGetAddress, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, int dim); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, int dim); /** * \brief Gets the filter-mode used by a texture reference * + * \deprecated + * * Returns in \p *pfm the filtering mode of the texture reference * \p hTexRef. * @@ -13020,11 +17437,13 @@ CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, i * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); /** * \brief Gets the format used by a texture reference * + * \deprecated + * * Returns in \p *pFormat and \p *pNumChannels the format and number * of components of the CUDA array bound to the texture reference \p hTexRef. * If \p pFormat or \p pNumChannels is NULL, it will be ignored. @@ -13046,11 +17465,13 @@ CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags */ -CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, CUtexref hTexRef); /** * \brief Gets the mipmap filtering mode for a texture reference * + * \deprecated + * * Returns the mipmap filtering mode in \p pfm that's used when reading memory through * the texture reference \p hTexRef. * @@ -13070,11 +17491,13 @@ CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, C * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); /** * \brief Gets the mipmap level bias for a texture reference * + * \deprecated + * * Returns the mipmap level bias in \p pBias that's added to the specified mipmap * level when reading memory through the texture reference \p hTexRef. * @@ -13094,11 +17517,13 @@ CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, CUtexref hTexRe * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef); /** * \brief Gets the min/max mipmap level clamps for a texture reference * + * \deprecated + * * Returns the min/max mipmap level clamps in \p pminMipmapLevelClamp and \p pmaxMipmapLevelClamp * that's used when reading memory through the texture reference \p hTexRef. 
* @@ -13119,11 +17544,13 @@ CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef); * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, float *pmaxMipmapLevelClamp, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, float *pmaxMipmapLevelClamp, CUtexref hTexRef); /** * \brief Gets the maximum anisotropy for a texture reference * + * \deprecated + * * Returns the maximum anisotropy in \p pmaxAniso that's used when reading memory through * the texture reference \p hTexRef. * @@ -13143,11 +17570,13 @@ CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, float * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef); /** * \brief Gets the border color used by a texture reference * + * \deprecated + * * Returns in \p pBorderColor, values of the RGBA color used by * the texture reference \p hTexRef. * The color value is of type float and holds color components in @@ -13170,11 +17599,13 @@ CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef); * \sa ::cuTexRefSetAddressMode, * ::cuTexRefSetAddressMode, ::cuTexRefSetBorderColor */ -CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef); /** * \brief Gets the flags used by a texture reference * + * \deprecated + * * Returns in \p *pFlags the flags of the texture reference \p hTexRef. * * \param pFlags - Returned flags @@ -13193,21 +17624,7 @@ CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef); * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, * ::cuTexRefGetFilterMode, ::cuTexRefGetFormat */ -CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, CUtexref hTexRef); - -/** @} */ /* END CUDA_TEXREF */ - -/** - * \defgroup CUDA_TEXREF_DEPRECATED Texture Reference Management [DEPRECATED] - * - * ___MANBRIEF___ deprecated texture reference management functions of the - * low-level CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ - * - * This section describes the deprecated texture reference management - * functions of the low-level CUDA driver application programming interface. - * - * @{ - */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, CUtexref hTexRef); /** * \brief Creates a texture reference @@ -13258,7 +17675,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef); /** - * \defgroup CUDA_SURFREF Surface Reference Management + * \defgroup CUDA_SURFREF_DEPRECATED Surface Reference Management [DEPRECATED] * * ___MANBRIEF___ surface reference management functions of the low-level CUDA * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ @@ -13272,6 +17689,8 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef); /** * \brief Sets the CUDA array for a surface reference. * + * \deprecated + * * Sets the CUDA array \p hArray to be read and written by the surface reference * \p hSurfRef. 
Any previous CUDA array state associated with the surface * reference is superseded by this function. \p Flags must be set to 0. @@ -13294,11 +17713,13 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef); * ::cuSurfRefGetArray, * ::cudaBindSurfaceToArray */ -CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, unsigned int Flags); +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, unsigned int Flags); /** * \brief Passes back the CUDA array bound to a surface reference. * + * \deprecated + * * Returns in \p *phArray the CUDA array bound to the surface reference * \p hSurfRef, or returns ::CUDA_ERROR_INVALID_VALUE if the surface reference * is not bound to any CUDA array. @@ -13315,11 +17736,10 @@ CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, unsigned * * \sa ::cuModuleGetSurfRef, ::cuSurfRefSetArray */ -CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef); +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef); -/** @} */ /* END CUDA_SURFREF */ +/** @} */ /* END CUDA_SURFREF_DEPRECATED */ -#if __CUDA_API_VERSION >= 5000 /** * \defgroup CUDA_TEXOBJECT Texture Object Management * @@ -13457,11 +17877,19 @@ CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef); * This is ignored if ::CUDA_RESOURCE_DESC::resType is ::CU_RESOURCE_TYPE_LINEAR. * * - ::CUDA_TEXTURE_DESC::flags can be any combination of the following: - * - ::CU_TRSF_READ_AS_INTEGER, which suppresses the default behavior of having the texture promote integer data to floating point data in the - * range [0, 1]. Note that texture with 32-bit integer format would not be promoted, regardless of whether or not this flag is specified. - * - ::CU_TRSF_NORMALIZED_COORDINATES, which suppresses the default behavior of having the texture coordinates range from [0, Dim) where Dim is - * the width or height of the CUDA array. Instead, the texture coordinates [0, 1.0) reference the entire breadth of the array dimension; Note - * that for CUDA mipmapped arrays, this flag has to be set. + * - ::CU_TRSF_READ_AS_INTEGER, which suppresses the default behavior of + * having the texture promote integer data to floating point data in the + * range [0, 1]. Note that texture with 32-bit integer format would not be + * promoted, regardless of whether or not this flag is specified. + * - ::CU_TRSF_NORMALIZED_COORDINATES, which suppresses the default behavior + * of having the texture coordinates range from [0, Dim) where Dim is the + * width or height of the CUDA array. Instead, the texture coordinates + * [0, 1.0) reference the entire breadth of the array dimension; Note that + * for CUDA mipmapped arrays, this flag has to be set. + * - ::CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, which disables any trilinear + * filtering optimizations. Trilinear optimizations improve texture filtering + * performance by allowing bilinear filtering on textures in scenarios where + * it can closely approximate the expected results. * * - ::CUDA_TEXTURE_DESC::maxAnisotropy specifies the maximum anisotropy ratio to be used when doing anisotropic filtering. This value will be * clamped to the range [1,16]. 
@@ -13710,7 +18138,6 @@ CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject); CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, CUsurfObject surfObject); /** @} */ /* END CUDA_SURFOBJECT */ -#endif /* __CUDA_API_VERSION >= 5000 */ /** * \defgroup CUDA_PEER_ACCESS Peer Context Memory Access @@ -13724,8 +18151,6 @@ CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, CUsur * @{ */ -#if __CUDA_API_VERSION >= 4000 - /** * \brief Queries if a device may directly access a peer device's memory. * @@ -13767,7 +18192,9 @@ CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, CUdevic * memory from the current context in \p peerContext, a separate symmetric call * to ::cuCtxEnablePeerAccess() is required. * - * There is a system-wide maximum of eight peer connections per device. + * Note that there are both device-wide and system-wide limitations per system + * configuration, as noted in the CUDA Programming Guide under the section + * "Peer-to-Peer Memory Access". * * Returns ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED if ::cuDeviceCanAccessPeer() indicates * that the ::CUdevice of the current context cannot directly access memory @@ -13832,10 +18259,6 @@ CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int Flags */ CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext); -#endif /* __CUDA_API_VERSION >= 4000 */ - -#if __CUDA_API_VERSION >= 8000 - /** * \brief Queries attributes of the link between two devices. * @@ -13876,8 +18299,6 @@ CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext); */ CUresult CUDAAPI cuDeviceGetP2PAttribute(int* value, CUdevice_P2PAttribute attrib, CUdevice srcDevice, CUdevice dstDevice); -#endif /* __CUDA_API_VERSION >= 8000 */ - /** @} */ /* END CUDA_PEER_ACCESS */ /** @@ -13962,8 +18383,6 @@ CUresult CUDAAPI cuGraphicsUnregisterResource(CUgraphicsResource resource); */ CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray(CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, unsigned int mipLevel); -#if __CUDA_API_VERSION >= 5000 - /** * \brief Get a mipmapped array through which to access a mapped graphics resource. * @@ -13995,9 +18414,6 @@ CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray(CUarray *pArray, CUgraphics */ CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray(CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource); -#endif /* __CUDA_API_VERSION >= 5000 */ - -#if __CUDA_API_VERSION >= 3020 /** * \brief Get a device pointer through which to access a mapped graphics resource. * @@ -14031,7 +18447,6 @@ CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray(CUmipmappedArray *pMi * ::cudaGraphicsResourceGetMappedPointer */ CUresult CUDAAPI cuGraphicsResourceGetMappedPointer(CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource); -#endif /* __CUDA_API_VERSION >= 3020 */ /** * \brief Set usage flags for mapping a graphics resource @@ -14153,8 +18568,73 @@ CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, CUgraphicsResource /** @} */ /* END CUDA_GRAPHICS */ -CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExportTableId); +/** + * \defgroup CUDA_DRIVER_ENTRY_POINT Driver Entry Point Access + * + * ___MANBRIEF___ driver entry point access functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the driver entry point access functions of the low-level CUDA + * driver application programming interface. 
+ * + * @{ + */ +/** + * \brief Returns the requested driver API function pointer + * + * Returns in \p **pfn the address of the CUDA driver function for the requested + * CUDA version and flags. + * + * The CUDA version is specified as (1000 * major + 10 * minor), so CUDA 11.2 + * should be specified as 11020. For a requested driver symbol, if the specified + * CUDA version is greater than or equal to the CUDA version in which the driver symbol + * was introduced, this API will return the function pointer to the corresponding + * versioned function. + * + * The pointer returned by the API should be cast to a function pointer matching the + * requested driver function's definition in the API header file. The function pointer + * typedef can be picked up from the corresponding typedefs header file. For example, + * cudaTypedefs.h consists of function pointer typedefs for driver APIs defined in cuda.h. + * + * The API will return ::CUDA_ERROR_NOT_FOUND if the requested driver function is not + * supported on the platform, no ABI compatible driver function exists for the specified + * \p cudaVersion or if the driver symbol is invalid. + * + * The requested flags can be: + * - ::CU_GET_PROC_ADDRESS_DEFAULT: This is the default mode. This is equivalent to + * ::CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM if the code is compiled with + * --default-stream per-thread compilation flag or the macro CUDA_API_PER_THREAD_DEFAULT_STREAM + * is defined; ::CU_GET_PROC_ADDRESS_LEGACY_STREAM otherwise. + * - ::CU_GET_PROC_ADDRESS_LEGACY_STREAM: This will enable the search for all driver symbols + * that match the requested driver symbol name except the corresponding per-thread versions. + * - ::CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM: This will enable the search for all + * driver symbols that match the requested driver symbol name including the per-thread + * versions. If a per-thread version is not found, the API will return the legacy version + * of the driver function. + * + * \param symbol - The base name of the driver API function to look for. As an example, + * for the driver API ::cuMemAlloc_v2, \p symbol would be cuMemAlloc and + * \p cudaVersion would be the ABI compatible CUDA version for the _v2 variant. + * \param pfn - Location to return the function pointer to the requested driver function + * \param cudaVersion - The CUDA version to look for the requested driver symbol + * \param flags - Flags to specify search options. 
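To ground the description, a sketch of fetching a versioned driver entry point at run time; the function-pointer typedef below is written by hand for the example rather than taken from cudaTypedefs.h:

    // Ask the driver for the cuMemAlloc entry point matching the CUDA 11.3 ABI
    // (11030) and call it through a locally declared function-pointer type.
    typedef CUresult (CUDAAPI *cuMemAlloc_pfn)(CUdeviceptr *dptr, size_t bytesize);
    cuMemAlloc_pfn pfnMemAlloc = NULL;
    if (cuGetProcAddress("cuMemAlloc", (void **)&pfnMemAlloc, 11030,
                         CU_GET_PROC_ADDRESS_DEFAULT) == CUDA_SUCCESS && pfnMemAlloc != NULL) {
        CUdeviceptr dptr;
        pfnMemAlloc(&dptr, 1 << 20);
        cuMemFree(dptr);
    }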
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_NOT_FOUND + * \note_version_mixing + * + * \sa + * ::cudaGetDriverEntryPoint + */ +CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags); + +/** @} */ /* END CUDA_DRIVER_ENTRY_POINT */ + +CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExportTableId); /** * CUDA API versioning support @@ -14233,6 +18713,7 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuStreamQuery #undef cuStreamSynchronize #undef cuEventRecord + #undef cuEventRecordWithFlags #undef cuLaunchKernel #undef cuLaunchHostFunc #undef cuGraphicsMapResources @@ -14249,184 +18730,160 @@ CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExp #undef cuStreamBeginCapture #undef cuStreamEndCapture #undef cuStreamIsCapturing + #undef cuStreamGetCaptureInfo + #undef cuStreamGetCaptureInfo_v2 + #undef cuGraphUpload #undef cuGraphLaunch -#endif /* __CUDA_API_VERSION_INTERNAL */ + #undef cuDevicePrimaryCtxRelease + #undef cuDevicePrimaryCtxReset + #undef cuDevicePrimaryCtxSetFlags + #undef cuIpcOpenMemHandle + #undef cuStreamCopyAttributes + #undef cuStreamSetAttribute + #undef cuStreamGetAttribute + #undef cuGraphInstantiate + #undef cuMemMapArrayAsync + #undef cuMemFreeAsync + #undef cuMemAllocAsync + #undef cuMemAllocFromPoolAsync + #undef cuStreamUpdateCaptureDependencies -#if defined(__CUDA_API_VERSION_INTERNAL) || (__CUDA_API_VERSION >= 4000 && __CUDA_API_VERSION < 6050) -CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, unsigned int Flags); -#endif /* defined(__CUDA_API_VERSION_INTERNAL) || (__CUDA_API_VERSION >= 4000 && __CUDA_API_VERSION < 6050) */ + CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, unsigned int Flags); + CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, unsigned int flags); + CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, void **optionValues, CUlinkState *stateOut); + CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, void **optionValues); + CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, const char *path, + unsigned int numOptions, CUjit_option *options, void **optionValues); + CUresult CUDAAPI cuTexRefSetAddress2D_v2(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION < 6050 -CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, unsigned int flags); -#endif /* defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION < 6050 */ + typedef unsigned int CUdeviceptr_v1; -#if defined(__CUDA_API_VERSION_INTERNAL) || (__CUDA_API_VERSION >= 5050 && __CUDA_API_VERSION < 6050) -CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, void **optionValues, CUlinkState *stateOut); -CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, void *data, size_t size, const char *name, - unsigned int numOptions, CUjit_option *options, void **optionValues); -CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, const char *path, - unsigned int numOptions, CUjit_option *options, void **optionValues); -#endif /* __CUDA_API_VERSION_INTERNAL || (__CUDA_API_VERSION >= 5050 && __CUDA_API_VERSION < 
6050) */ + typedef struct CUDA_MEMCPY2D_v1_st + { + unsigned int srcXInBytes; /**< Source X in bytes */ + unsigned int srcY; /**< Source Y */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr_v1 srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ -#if defined(__CUDA_API_VERSION_INTERNAL) || (__CUDA_API_VERSION >= 3020 && __CUDA_API_VERSION < 4010) -CUresult CUDAAPI cuTexRefSetAddress2D_v2(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); -#endif /* __CUDA_API_VERSION_INTERNAL || (__CUDA_API_VERSION >= 3020 && __CUDA_API_VERSION < 4010) */ + unsigned int dstXInBytes; /**< Destination X in bytes */ + unsigned int dstY; /**< Destination Y */ + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr_v1 dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ -/** - * CUDA API made obselete at API version 3020 - */ -#if defined(__CUDA_API_VERSION_INTERNAL) - #define CUdeviceptr CUdeviceptr_v1 - #define CUDA_MEMCPY2D_st CUDA_MEMCPY2D_v1_st - #define CUDA_MEMCPY2D CUDA_MEMCPY2D_v1 - #define CUDA_MEMCPY3D_st CUDA_MEMCPY3D_v1_st - #define CUDA_MEMCPY3D CUDA_MEMCPY3D_v1 - #define CUDA_ARRAY_DESCRIPTOR_st CUDA_ARRAY_DESCRIPTOR_v1_st - #define CUDA_ARRAY_DESCRIPTOR CUDA_ARRAY_DESCRIPTOR_v1 - #define CUDA_ARRAY3D_DESCRIPTOR_st CUDA_ARRAY3D_DESCRIPTOR_v1_st - #define CUDA_ARRAY3D_DESCRIPTOR CUDA_ARRAY3D_DESCRIPTOR_v1 -#endif /* CUDA_FORCE_LEGACY32_INTERNAL */ + unsigned int WidthInBytes; /**< Width of 2D memory copy in bytes */ + unsigned int Height; /**< Height of 2D memory copy */ + } CUDA_MEMCPY2D_v1; -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION < 3020 + typedef struct CUDA_MEMCPY3D_v1_st + { + unsigned int srcXInBytes; /**< Source X in bytes */ + unsigned int srcY; /**< Source Y */ + unsigned int srcZ; /**< Source Z */ + unsigned int srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr_v1 srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + void *reserved0; /**< Must be NULL */ + unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ + unsigned int srcHeight; /**< Source height (ignored when src is array; may be 0 if Depth==1) */ -typedef unsigned int CUdeviceptr; + unsigned int dstXInBytes; /**< Destination X in bytes */ + unsigned int dstY; /**< Destination Y */ + unsigned int dstZ; /**< Destination Z */ + unsigned int dstLOD; /**< Destination LOD */ + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr_v1 dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + void *reserved1; /**< Must be NULL */ + unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ + unsigned int dstHeight; /**< Destination height (ignored when dst is array; may be 0 if Depth==1) */ -typedef struct CUDA_MEMCPY2D_st -{ - unsigned int srcXInBytes; /**< Source X in bytes */ - unsigned int srcY; /**< Source Y */ - CUmemorytype 
srcMemoryType; /**< Source memory type (host, device, array) */ - const void *srcHost; /**< Source host pointer */ - CUdeviceptr srcDevice; /**< Source device pointer */ - CUarray srcArray; /**< Source array reference */ - unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ + unsigned int WidthInBytes; /**< Width of 3D memory copy in bytes */ + unsigned int Height; /**< Height of 3D memory copy */ + unsigned int Depth; /**< Depth of 3D memory copy */ + } CUDA_MEMCPY3D_v1; - unsigned int dstXInBytes; /**< Destination X in bytes */ - unsigned int dstY; /**< Destination Y */ - CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ - void *dstHost; /**< Destination host pointer */ - CUdeviceptr dstDevice; /**< Destination device pointer */ - CUarray dstArray; /**< Destination array reference */ - unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ + typedef struct CUDA_ARRAY_DESCRIPTOR_v1_st + { + unsigned int Width; /**< Width of array */ + unsigned int Height; /**< Height of array */ - unsigned int WidthInBytes; /**< Width of 2D memory copy in bytes */ - unsigned int Height; /**< Height of 2D memory copy */ -} CUDA_MEMCPY2D; + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + } CUDA_ARRAY_DESCRIPTOR_v1; -typedef struct CUDA_MEMCPY3D_st -{ - unsigned int srcXInBytes; /**< Source X in bytes */ - unsigned int srcY; /**< Source Y */ - unsigned int srcZ; /**< Source Z */ - unsigned int srcLOD; /**< Source LOD */ - CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ - const void *srcHost; /**< Source host pointer */ - CUdeviceptr srcDevice; /**< Source device pointer */ - CUarray srcArray; /**< Source array reference */ - void *reserved0; /**< Must be NULL */ - unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ - unsigned int srcHeight; /**< Source height (ignored when src is array; may be 0 if Depth==1) */ + typedef struct CUDA_ARRAY3D_DESCRIPTOR_v1_st + { + unsigned int Width; /**< Width of 3D array */ + unsigned int Height; /**< Height of 3D array */ + unsigned int Depth; /**< Depth of 3D array */ - unsigned int dstXInBytes; /**< Destination X in bytes */ - unsigned int dstY; /**< Destination Y */ - unsigned int dstZ; /**< Destination Z */ - unsigned int dstLOD; /**< Destination LOD */ - CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ - void *dstHost; /**< Destination host pointer */ - CUdeviceptr dstDevice; /**< Destination device pointer */ - CUarray dstArray; /**< Destination array reference */ - void *reserved1; /**< Must be NULL */ - unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ - unsigned int dstHeight; /**< Destination height (ignored when dst is array; may be 0 if Depth==1) */ + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + unsigned int Flags; /**< Flags */ + } CUDA_ARRAY3D_DESCRIPTOR_v1; - unsigned int WidthInBytes; /**< Width of 3D memory copy in bytes */ - unsigned int Height; /**< Height of 3D memory copy */ - unsigned int Depth; /**< Depth of 3D memory copy */ -} CUDA_MEMCPY3D; + CUresult CUDAAPI cuDeviceTotalMem(unsigned int *bytes, CUdevice dev); + CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); + CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr_v1 *dptr, unsigned int *bytes, CUmodule hmod, const char *name); + CUresult CUDAAPI cuMemGetInfo(unsigned 
int *free, unsigned int *total); + CUresult CUDAAPI cuMemAlloc(CUdeviceptr_v1 *dptr, unsigned int bytesize); + CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr_v1 *dptr, unsigned int *pPitch, unsigned int WidthInBytes, unsigned int Height, unsigned int ElementSizeBytes); + CUresult CUDAAPI cuMemFree(CUdeviceptr_v1 dptr); + CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr_v1 *pbase, unsigned int *psize, CUdeviceptr_v1 dptr); + CUresult CUDAAPI cuMemAllocHost(void **pp, unsigned int bytesize); + CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr_v1 *pdptr, void *p, unsigned int Flags); + CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr_v1 dstDevice, const void *srcHost, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr_v1 srcDevice, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr_v1 dstDevice, CUdeviceptr_v1 srcDevice, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, unsigned int dstOffset, CUdeviceptr_v1 srcDevice, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr_v1 dstDevice, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, unsigned int dstOffset, const void *srcHost, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, unsigned int dstOffset, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, unsigned int dstOffset, const void *srcHost, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D_v1 *pCopy); + CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D_v1 *pCopy); + CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D_v1 *pCopy); + CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr_v1 dstDevice, const void *srcHost, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr_v1 srcDevice, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr_v1 dstDevice, CUdeviceptr_v1 srcDevice, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D_v1 *pCopy, CUstream hStream); + CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D_v1 *pCopy, CUstream hStream); + CUresult CUDAAPI cuMemsetD8(CUdeviceptr_v1 dstDevice, unsigned char uc, unsigned int N); + CUresult CUDAAPI cuMemsetD16(CUdeviceptr_v1 dstDevice, unsigned short us, unsigned int N); + CUresult CUDAAPI cuMemsetD32(CUdeviceptr_v1 dstDevice, unsigned int ui, unsigned int N); + CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, unsigned char uc, unsigned int Width, unsigned int Height); + CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, unsigned short us, unsigned int Width, unsigned int Height); + CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, unsigned int ui, unsigned int Width, unsigned int Height); + CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, const CUDA_ARRAY_DESCRIPTOR_v1 *pAllocateArray); + CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR_v1 *pArrayDescriptor, CUarray hArray); + CUresult CUDAAPI cuArray3DCreate(CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR_v1 
*pAllocateArray); + CUresult CUDAAPI cuArray3DGetDescriptor(CUDA_ARRAY3D_DESCRIPTOR_v1 *pArrayDescriptor, CUarray hArray); + CUresult CUDAAPI cuTexRefSetAddress(unsigned int *ByteOffset, CUtexref hTexRef, CUdeviceptr_v1 dptr, unsigned int bytes); + CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR_v1 *desc, CUdeviceptr_v1 dptr, unsigned int Pitch); + CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr_v1 *pdptr, CUtexref hTexRef); + CUresult CUDAAPI cuGraphicsResourceGetMappedPointer(CUdeviceptr_v1 *pDevPtr, unsigned int *pSize, CUgraphicsResource resource); -typedef struct CUDA_ARRAY_DESCRIPTOR_st -{ - unsigned int Width; /**< Width of array */ - unsigned int Height; /**< Height of array */ + CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); + CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx); + CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx); + CUresult CUDAAPI cuStreamDestroy(CUstream hStream); + CUresult CUDAAPI cuEventDestroy(CUevent hEvent); + CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); + CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); + CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); - CUarray_format Format; /**< Array format */ - unsigned int NumChannels; /**< Channels per array element */ -} CUDA_ARRAY_DESCRIPTOR; - -typedef struct CUDA_ARRAY3D_DESCRIPTOR_st -{ - unsigned int Width; /**< Width of 3D array */ - unsigned int Height; /**< Height of 3D array */ - unsigned int Depth; /**< Depth of 3D array */ - - CUarray_format Format; /**< Array format */ - unsigned int NumChannels; /**< Channels per array element */ - unsigned int Flags; /**< Flags */ -} CUDA_ARRAY3D_DESCRIPTOR; - -CUresult CUDAAPI cuDeviceTotalMem(unsigned int *bytes, CUdevice dev); -CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); -CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, unsigned int *bytes, CUmodule hmod, const char *name); -CUresult CUDAAPI cuMemGetInfo(unsigned int *free, unsigned int *total); -CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, unsigned int bytesize); -CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, unsigned int *pPitch, unsigned int WidthInBytes, unsigned int Height, unsigned int ElementSizeBytes); -CUresult CUDAAPI cuMemFree(CUdeviceptr dptr); -CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, unsigned int *psize, CUdeviceptr dptr); -CUresult CUDAAPI cuMemAllocHost(void **pp, unsigned int bytesize); -CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, unsigned int Flags); -CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, unsigned int dstOffset, CUdeviceptr srcDevice, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, unsigned int dstOffset, const void *srcHost, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, unsigned int dstOffset, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); -CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, unsigned int dstOffset, 
const void *srcHost, unsigned int ByteCount, CUstream hStream); -CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount, CUstream hStream); -CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy); -CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy); -CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy); -CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, unsigned int ByteCount, CUstream hStream); -CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice, unsigned int ByteCount, CUstream hStream); -CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, unsigned int ByteCount, CUstream hStream); -CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream); -CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream); -CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, unsigned int N); -CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, unsigned int N); -CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, unsigned int N); -CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, unsigned int dstPitch, unsigned char uc, unsigned int Width, unsigned int Height); -CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, unsigned int dstPitch, unsigned short us, unsigned int Width, unsigned int Height); -CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, unsigned int dstPitch, unsigned int ui, unsigned int Width, unsigned int Height); -CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, const CUDA_ARRAY_DESCRIPTOR *pAllocateArray); -CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, CUarray hArray); -CUresult CUDAAPI cuArray3DCreate(CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray); -CUresult CUDAAPI cuArray3DGetDescriptor(CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray); -CUresult CUDAAPI cuTexRefSetAddress(unsigned int *ByteOffset, CUtexref hTexRef, CUdeviceptr dptr, unsigned int bytes); -CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, unsigned int Pitch); -CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef); -CUresult CUDAAPI cuGraphicsResourceGetMappedPointer(CUdeviceptr *pDevPtr, unsigned int *pSize, CUgraphicsResource resource); -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION < 3020 */ -#if defined(__CUDA_API_VERSION_INTERNAL) || __CUDA_API_VERSION < 4000 -CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); -CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx); -CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx); -CUresult CUDAAPI cuStreamDestroy(CUstream hStream); -CUresult CUDAAPI cuEventDestroy(CUevent hEvent); -#endif /* __CUDA_API_VERSION_INTERNAL || __CUDA_API_VERSION < 4000 */ -#if defined(__CUDA_API_VERSION_INTERNAL) - #undef CUdeviceptr - #undef CUDA_MEMCPY2D_st - #undef CUDA_MEMCPY2D - #undef CUDA_MEMCPY3D_st - #undef CUDA_MEMCPY3D - #undef CUDA_ARRAY_DESCRIPTOR_st - #undef CUDA_ARRAY_DESCRIPTOR - #undef CUDA_ARRAY3D_DESCRIPTOR_st - #undef CUDA_ARRAY3D_DESCRIPTOR -#endif /* __CUDA_API_VERSION_INTERNAL */ - -#if defined(__CUDA_API_VERSION_INTERNAL) CUresult CUDAAPI cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); CUresult CUDAAPI cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); CUresult CUDAAPI cuMemcpyDtoD_v2(CUdeviceptr dstDevice, CUdeviceptr srcDevice, size_t 
ByteCount); @@ -14474,6 +18931,7 @@ CUresult CUDAAPI cuEventDestroy(CUevent hEvent); CUresult CUDAAPI cuStreamQuery(CUstream hStream); CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); + CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, unsigned int flags); CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, void *userData); CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); @@ -14488,16 +18946,49 @@ CUresult CUDAAPI cuEventDestroy(CUevent hEvent); CUresult CUDAAPI cuSignalExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); CUresult CUDAAPI cuWaitExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream); + CUresult CUDAAPI cuStreamBeginCapture_ptsz(CUstream hStream); + CUresult CUDAAPI cuStreamBeginCapture_v2(CUstream hStream, CUstreamCaptureMode mode); CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *captureStatus); + CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); + CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out); + CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraph, CUstream hStream); CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraph, CUstream hStream); + CUresult CUDAAPI cuStreamCopyAttributes(CUstream dstStream, CUstream srcStream); + CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, CUstreamAttrValue *value); + CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, const CUstreamAttrValue *param); + + CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int Flags); + CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); + CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, unsigned int count, CUstream hStream); + + CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); + CUresult CUDAAPI cuMemAllocAsync(CUdeviceptr *dptr, size_t bytesize, CUstream hStream); + CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, CUmemoryPool pool, CUstream hStream); + + CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags); +#elif defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) +static inline CUresult cuGetProcAddress_ptsz(const char *symbol, void **funcPtr, int driverVersion, cuuint64_t flags) { + const int procAddressMask = (CU_GET_PROC_ADDRESS_LEGACY_STREAM| + CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM); + if ((flags & procAddressMask) == 0) { + flags |= 
CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM; + } + return cuGetProcAddress(symbol, funcPtr, driverVersion, flags); +} +#define cuGetProcAddress cuGetProcAddress_ptsz #endif #ifdef __cplusplus } #endif -#undef __CUDA_API_VERSION +#if defined(__GNUC__) + #if defined(__CUDA_API_PUSH_VISIBILITY_DEFAULT) + #pragma GCC visibility pop + #endif +#endif + #undef __CUDA_DEPRECATED #endif /* __cuda_cuda_h__ */ diff --git a/python/src/triton.cc b/python/src/triton.cc index 4d7df76ff..cb6923afe 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -101,12 +101,12 @@ void hip_enqueue(uint64_t stream, uint64_t kernel, } -std::string pow2_divisor(long N){ - if(N % 16 == 0) return "16"; - if(N % 8 == 0) return "8"; - if(N % 4 == 0) return "4"; - if(N % 2 == 0) return "2"; - return "1"; +long pow2_divisor(long N){ + if(N % 16 == 0) return 16; + if(N % 8 == 0) return 8; + if(N % 4 == 0) return 4; + if(N % 2 == 0) return 2; + return 1; } // Returns something like "int16", whether dtype is a torch.dtype or @@ -127,6 +127,14 @@ std::string dtype_cache_key_part(const py::object& dtype) { } } +size_t get_pointer_range_size(uint64_t addr){ + if(addr == 0) + return 0; + size_t size; + drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr); + return size; +} + // Launch void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, @@ -187,7 +195,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f continue; // values divisible by small powers of 2 are specialized cache_key += "[multipleof("; - cache_key += pow2_divisor(value); + cache_key += std::to_string(pow2_divisor(value)); cache_key += ")]"; continue; } @@ -213,12 +221,15 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f py::object data_ptr = arg.attr("data_ptr")(); long value = data_ptr.cast(); params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); + // copy param std::memcpy(params_ptr, &value, 8); params_ptr += 8; + // udpate cache key cache_key += dtype_cache_key_part(arg.attr("dtype")); cache_key += "*"; cache_key += "[multipleof("; - cache_key += pow2_divisor(value); + size_t range_size = get_pointer_range_size(value); + cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size))); cache_key += ")]"; continue; } @@ -268,6 +279,10 @@ void init_triton_runtime(py::module &&m) { } ); + // get range size for the given pointer + m.def("get_pointer_range_size", &get_pointer_range_size); + + // cache key m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b7da2047e..63e3d0aa0 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -674,9 +674,17 @@ class Kernel: def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes - args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)] - attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) - if isinstance(a, int) and i not in self.fn.do_not_specialize} + attributes = dict() + for i, arg in enumerate(wargs): + if i in self.fn.do_not_specialize: + continue + if 
isinstance(arg, int): + attributes[i] = Kernel.pow2_divisor(arg) + elif i in tensor_idxs: + addr = arg.data_ptr() + range_size = _triton.runtime.get_pointer_range_size(addr) + attributes[i] = min(Kernel.pow2_divisor(addr), + Kernel.pow2_divisor(range_size)) # transforms ints whose value is one into constants for just-in-time compilation constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 4ef6408d0..6f1bf3ba7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -165,6 +165,7 @@ class block: self.numel = 1 for s in self.shape: self.numel *= s + self.numel = constexpr(self.numel) # Data-type wrapper self.dtype = block._init_dtype(self.handle.type.scalar) @@ -873,7 +874,7 @@ def ravel(x): :param x: the input block :type x: Block """ - return triton.language.reshape(x, [x.type.numel]) + return triton.language.reshape(x, [x.numel]) @triton.jit From 2a944ded53f73b124563a7130a948bdefd23e627 Mon Sep 17 00:00:00 2001 From: daadaada Date: Fri, 14 Jan 2022 15:38:32 +0800 Subject: [PATCH 049/215] [TESTS] Added bfloat16 tests (#430) --- python/test/unit/language/test_core.py | 2 ++ python/test/unit/operators/test_matmul.py | 12 +++++++++--- python/triton/code_gen.py | 2 +- python/triton/ops/matmul.py | 2 +- python/triton/ops/matmul_perf_model.py | 6 +++++- python/triton/testing.py | 4 ++++ 6 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e32622005..64f60e260 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -732,6 +732,8 @@ def test_dot(epilogue, allow_tf32, device='cuda'): assert 'st.global.v4' in ptx if allow_tf32: assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx + else: + assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx def test_dot_without_load(): diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 1d413a0e6..514fbab7b 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -4,6 +4,7 @@ import pytest import torch import triton +import triton._C.libtriton.triton as _triton @pytest.mark.parametrize( @@ -48,7 +49,7 @@ import triton (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE), (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE), (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE), - ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] ], # n-stage *[ @@ -61,11 +62,16 @@ import triton # split-k (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE), - ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] ] ), ) def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and DTYPE == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + if DTYPE == "bfloat16" and SPLIT_K != 1: + pytest.skip("bfloat16 
matmuls don't allow split_k for now") torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} @@ -81,7 +87,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, N = BLOCK_N if N is None else N K = BLOCK_K * SPLIT_K if K is None else K # allocate/transpose inputs - DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] + DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE] a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 63e3d0aa0..960df4efc 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -838,7 +838,7 @@ class Autotuner: # prune configs pruned_configs = self.configs if self.prune_num_stages_by: - pruned_configs = self.prune_num_stages_by(self.configs) + pruned_configs = self.prune_num_stages_by(self.configs, self.nargs) if self.perf_model: top_k = self.configs_top_k if isinstance(top_k, float) and top_k <= 1.0: diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index d7af57406..9466b9ba7 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -87,7 +87,7 @@ def _kernel(A, B, C, M, N, K, acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = acc.to(tl.float16) + acc = acc.to(C.dtype.element_ty) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index af4f3eed8..98a85bc85 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -78,12 +78,16 @@ def estimate_matmul_time( return total_time_ms -def prune_num_stages(configs): +def prune_num_stages(configs, named_args): backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() cc = _triton.runtime.cc(backend, device) # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + # Some dtypes do not allow atomic_add + if named_args['A'].dtype == torch.bfloat16: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) configs_map = {} for config in configs: diff --git a/python/triton/testing.py b/python/triton/testing.py index 310e754ed..199226ea1 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -65,8 +65,12 @@ def mask_tensor(x, mask, block, value=0): def assert_almost_equal(x, y, decimal=2, err_msg=''): import numpy.testing as npt if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() x = x.cpu().detach().numpy() if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() y = y.cpu().detach().numpy() npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) From e0c5709cc8119d4b1b1c334b2d1c0923001da4c4 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 17 Jan 2022 18:00:03 -0800 Subject: [PATCH 050/215] [FRONTEND] Fixed semantics bug on ptr to bool conversions (#432) --- lib/ir/dispatch.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 12be53bfc..73d4ddc94 100644 --- 
a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -527,8 +527,15 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build else return builder->create_si_to_fp(input, dst_ty); } - if (src_sca_ty->is_pointer_ty() && !dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::PtrToInt, input, dst_ty); + if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){ + int bitwidth = dst_sca_ty->get_integer_bitwidth(); + if(bitwidth == 64) + return builder->create_cast(ir::PtrToInt, input, dst_ty); + if(bitwidth == 1) + return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder), + builder->get_int64(0), + builder); + } if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) return builder->create_cast(ir::IntToPtr, input, dst_ty); // Ptr -> Ptr From 4c97d1ecd778c287e57617eb01c4edd03cf855f6 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 20 Jan 2022 10:55:59 -0800 Subject: [PATCH 051/215] [FRONTEND] Bunch of fixes here and there (#436) --- lib/ir/dispatch.cc | 7 +++- python/src/triton.cc | 6 ++- python/test/unit/language/test_core.py | 2 +- python/test/unit/runtime/test_cache.py | 6 +-- python/triton/code_gen.py | 52 ++++++++++++++------------ python/triton/language/core.py | 19 +++++++++- python/triton/language/random.py | 18 ++++++--- 7 files changed, 71 insertions(+), 39 deletions(-) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 73d4ddc94..69c76b5e5 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -745,8 +745,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir: x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); } - if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty()) - throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty()); + ir::type* x_ty = x->get_type()->get_scalar_ty(); + ir::type* y_ty = y->get_type()->get_scalar_ty(); + ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO); + x = dispatch::cast(x, ty, builder); + y = dispatch::cast(y, ty, builder); return builder->create_select(condition, x, y); } diff --git a/python/src/triton.cc b/python/src/triton.cc index cb6923afe..8ceb14200 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -329,7 +329,6 @@ void init_triton_runtime(py::module &&m) { // cuda will block if too many ops are enqueued Py_BEGIN_ALLOW_THREADS - drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); @@ -466,6 +465,9 @@ std::tuple hip_load_binary(const std::string& name, asm_map_ std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, uint64_t device, int num_warps, int num_stages, asm_map_t &asm_map){ + + int n_shared_bytes; + Py_BEGIN_ALLOW_THREADS llvm::LLVMContext ctx; // device properties CUdevice dev = (CUdevice)device; @@ -476,7 +478,6 @@ std::tuple cu_compile_ttir(const std::string& name, drv::dispatch::cuDriverGetVersion(&version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); - int n_shared_bytes; auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); std::string tmp; llvm::raw_string_ostream llir(tmp); @@ -492,6 +493,7 @@ std::tuple cu_compile_ttir(const std::string& name, py::bytes bytes(cubin); asm_map["cubin"] = bytes; } + Py_END_ALLOW_THREADS return 
std::make_tuple(name, asm_map, n_shared_bytes); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 64f60e260..4d6e32aa6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -79,7 +79,7 @@ def to_numpy(x): def patch_kernel(template, to_replace): - kernel = copy.deepcopy(template) + kernel = triton.JITFunction(template.fn) for key, value in to_replace.items(): kernel.src = kernel.src.replace(key, value) return kernel diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 51c69b5b6..48797b51a 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -39,9 +39,9 @@ def kernel_nospec(X, i, BLOCK: tl.constexpr): def apply_src_change(target, old, new): - delattr(kernel.fn, 'hash') - delattr(function_1.fn, 'hash') - delattr(function_2.fn, 'hash') + kernel.hash = None + function_1.hash = None + function_2.hash = None function_1.src = function_1.src.replace(old, new) target.src = target.src.replace(old, new) ret = target.cache_key diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 960df4efc..347635e32 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -651,7 +651,7 @@ class Kernel: prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR # export symbols visible from self.fn into code-generator object - gscope = self.fn.fn.__globals__ + gscope = self.fn.__globals__ generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) try: generator.visit(self.fn.parse()) @@ -723,7 +723,7 @@ class Kernel: pickle.dump({"binary": binary, "key": key}, f) os.rename(bin_cache_path + ".tmp", bin_cache_path) if JITFunction.cache_hook is not None: - name = self.fn.fn.__name__ + name = self.fn.__name__ info = key.split('-')[-3:] num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] # make signature human-readable @@ -885,8 +885,6 @@ def version_key(): ptxas_version = '' return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) -# 3 - class DependenciesFinder(ast.NodeVisitor): @@ -910,17 +908,14 @@ class DependenciesFinder(ast.NodeVisitor): func = self.visit(node.func) if func is None: return - if isinstance(func, triton.JITFunction): - func = func.fn - module = inspect.getmodule(func) - if module and module.__name__.startswith('triton.'): - return if inspect.isbuiltin(func): return - if not hasattr(func, 'hash'): - src = textwrap.dedent(inspect.getsource(func)) - tree = ast.parse(src) - finder = DependenciesFinder(func.__globals__, src) + if func.__module__ and func.__module__.startswith('triton.'): + return + assert isinstance(func, triton.JITFunction) + if func.hash is None: + tree = ast.parse(func.src) + finder = DependenciesFinder(func.__globals__, func.src) finder.visit(tree) func.hash = finder.ret self.ret = (self.ret + func.hash).encode("utf-8") @@ -941,10 +936,12 @@ class JITFunction: self.version = version self.src = textwrap.dedent(inspect.getsource(fn)) - self.do_not_specialize = [] if do_not_specialize is None else\ - [self.arg_names.index(arg) for arg in do_not_specialize] + self.src = self.src[self.src.find("def"):] + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] # cache for callable driver objects 
(e.g. CUkernel) self.bin_cache = dict() + self.hash = None # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel_decorators = [] @@ -954,15 +951,19 @@ class JITFunction: self.__annotations__ = fn.__annotations__ # forward docs self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ @property @functools.lru_cache() def cache_key(self): - if not hasattr(self.fn, 'hash'): - dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src) + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) dependencies_finder.visit(self.parse()) - self.fn.hash = dependencies_finder.ret + version_key() - return self.fn.hash + self.hash = dependencies_finder.ret + version_key() + return self.hash # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. @@ -974,14 +975,20 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - def __call__(self, *args, generator: CodeGenerator): + def __call__(self, *args, generator: CodeGenerator, **kwargs): try: + from inspect import getcallargs + arg_values = getcallargs(self.fn, *args, **kwargs) + arg_values = [arg_values[name] for name in self.arg_names] + arg_values = [arg if isinstance(arg, triton.language.block) + else triton.language.constexpr(arg) for arg in arg_values] + gscope = generator.gscope.copy() lscope = generator.lscope.copy() values = generator.module.get_values().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() - ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope generator.module.set_values(values) @@ -1001,8 +1008,7 @@ class JITFunction: self.kernel = None super(JITFunction, self).__setattr__(name, value) if name == 'src': - if hasattr(self.fn, 'hash'): - delattr(self.fn, 'hash') + self.hash = None JITFunction.cache_key.fget.cache_clear() def _init_kernel(self): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6f1bf3ba7..df3b1f4cf 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -168,6 +168,8 @@ class block: self.numel = constexpr(self.numel) # Data-type wrapper self.dtype = block._init_dtype(self.handle.type.scalar) + # Shape is a constexpr + self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: # ex. 
"float32[3,4]" @@ -297,7 +299,7 @@ class block: if sl is None: dst_shape.append(1) elif sl == slice(None, None, None): - dst_shape.append(src_shape[curr]) + dst_shape.append(src_shape[curr].value) curr += 1 ret = frontend.reshape(self, dst_shape, _builder) return ret @@ -320,8 +322,15 @@ class constexpr: """ def __init__(self, value): - self.value = value + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + # def __add__(self, other): return self.value + other.value @@ -516,6 +525,7 @@ def reshape(input, shape, _builder=None): :type shape: Tuple[int] """ + shape = [x.value for x in shape] return frontend.reshape(input, shape, _builder) @@ -908,3 +918,8 @@ def swizzle2d(i, j, size_i, size_j, size_g): new_i = off_i + (ij % size_g) new_j = (ij % size_gj) // size_g return new_i, new_j + + +@triton.jit +def zeros_like(input): + return zeros(input.shape, input.dtype) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 69d7f4c4d..d5691bb72 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -13,7 +13,7 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox @triton.jit -def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). """ @@ -32,6 +32,14 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): return c0, c1, c2, c3 +@triton.jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = seed.to(tl.uint64) + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + @triton.jit def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ @@ -60,11 +68,9 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting - seed = seed.to(tl.uint64) - seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) - seed_lo = (seed & 0xffffffff).to(tl.uint32) - return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) # ------------------- From ccf9abe0ba081f13d37c9966a466f9984cd92747 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 21 Jan 2022 18:05:55 -0800 Subject: [PATCH 052/215] [FRONTEND][RANDOM] Improved backward compatibility of RNG (#438) The unsigned int PR definitely improved our RNG. However, it requires different floating point arithmetics which, means the results are not bit-wise identical to how they were before. This commit revives backward compatibility, but we should change it back to the "right" way later. 
--- python/triton/language/random.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index d5691bb72..c95eac9fc 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -77,13 +77,23 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # rand # ------------------- +# @triton.jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + @triton.jit def uint32_to_uniform_float(x): """ Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ - two_to_the_minus_32: tl.constexpr = 2.328306e-10 - return x * two_to_the_minus_32 + x = x.to(tl.int32, bitcast=True) + max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + x = tl.where(x < 0, -x - 1, x) + return x * max @triton.jit From 3a23c1dd332718f3e62210280de707e496d57f31 Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Sun, 23 Jan 2022 17:24:02 -0500 Subject: [PATCH 053/215] [BACKEND] minor, hotfix for gcc compilation (#439) --- include/triton/ir/type.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index c27ce48cf..47c9b5f85 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace triton{ namespace ir{ From 59d371c6eb3e8f904c9bc879e5df07b0b79131ad Mon Sep 17 00:00:00 2001 From: daadaada Date: Fri, 28 Jan 2022 01:12:44 +0800 Subject: [PATCH 054/215] [BACKEND] Added Int8 mma (#440) --- include/triton/codegen/analysis/layout.h | 2 + lib/codegen/analysis/layout.cc | 24 +++- lib/codegen/analysis/swizzle.cc | 6 + lib/codegen/selection/common.h | 78 ---------- lib/codegen/selection/generator.cc | 175 +++++++++++++++++++++-- lib/codegen/transform/peephole.cc | 15 +- lib/codegen/transform/prefetch.cc | 3 + lib/driver/llvm.cc | 6 +- lib/ir/dispatch.cc | 6 +- python/setup.py | 2 +- python/test/unit/language/test_core.py | 30 ++-- 11 files changed, 232 insertions(+), 115 deletions(-) delete mode 100644 lib/codegen/selection/common.h diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index dc5150f05..b6376d7cc 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -268,6 +268,7 @@ public: void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } int get_mma_vec() { return mma_vec_;} int get_mma_strided() { return mma_strided_; } + bool allow_swizzle() const { return allow_swizzle_; } data_layout* get_arg_layout() { return arg_layout_; } private: @@ -281,6 +282,7 @@ private: data_layout* arg_layout_; int mma_vec_; int mma_strided_; + bool allow_swizzle_ = true; target *tgt_; }; diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index d00959e45..2206f5b6a 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){ result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) || (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) || (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() && - x->allow_tf32() && sm >= 80); + x->allow_tf32() && sm >= 80) || + 
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) && + sm >= 80); } return result; } @@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) { return mma_type; } } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) { - throw std::runtime_error("integer tensor cores are not yet supported"); + // throw std::runtime_error("integer tensor cores are not yet supported"); // // integer tensor cores // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) { // mma_type = mma_layout::INT32_INT1_INT1_INT32; @@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) { // mma_type = mma_layout::INT32_INT4_INT4_INT32; // return mma_type; // } - // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) { - // mma_type = mma_layout::INT32_INT8_INT8_INT32; - // return mma_type; - // } + if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) { + mma_type = mma_layout::INT32_INT8_INT8_INT32; + return mma_type; + } } } return mma_layout::NOT_APPLICABLE; @@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg, std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_)); mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2]; + + // for now, disable swizzle when using lds.8 + if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32) + if (order_[0] == 0) // need transpose + allow_swizzle_ = false; } else if (hmma_dot_b_) { assert(order_.size() == 2); std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_)); mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k mma_strided_ = order_[0] == 1 ? 
mat_shape[2] : mat_shape[1]; + + // for now, disable swizzle when using lds.8 + if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32) + if (order_[0] == 1) // need transpose + allow_swizzle_ = false; } // size diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index 1dbae10d4..5737f80a0 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -41,9 +41,15 @@ void swizzle::run(ir::module &) { vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); } else { + if (!layout->allow_swizzle()) { + per_phase_[layout] = 1; + max_phase_[layout] = 1; + vec_[layout] = 1; + } else { per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; vec_[layout] = layout->get_mma_vec(); + } } } } diff --git a/lib/codegen/selection/common.h b/lib/codegen/selection/common.h deleted file mode 100644 index c4b0951da..000000000 --- a/lib/codegen/selection/common.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include -#include -#include -#include "triton/codegen/selection/generator.h" -#include "triton/codegen/target.h" -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/swizzle.h" -#include "triton/codegen/transform/coalesce.h" -#include "triton/ir/context.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/type.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/InlineAsm.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" - -namespace triton::codegen { -// types -#define void_ty builder_->getVoidTy() -#define f16_ty builder_->getHalfTy() -#define bf16_ty builder_->getBFloatTy() -#define f32_ty builder_->getFloatTy() -#define i8_ty builder_->getInt8Ty() -#define i32_ty builder_->getInt32Ty() -#define vec_ty(type, num_el) VectorType::get(type, num_el, false) -#define ptr_ty(...) PointerType::get(__VA_ARGS__) -// constants -#define i32(...) builder_->getInt32(__VA_ARGS__) -// ops -#define and_(...) builder_->CreateAnd(__VA_ARGS__) -#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__) -#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__) -#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__) -#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__) -#define br(...) builder_->CreateBr(__VA_ARGS__) -#define call(...) builder_->CreateCall(__VA_ARGS__) -#define cast(...) builder_->CreateCast(__VA_ARGS__) -#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__) -#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__) -#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__) -#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__) -#define fadd(...) builder_->CreateFAdd(__VA_ARGS__) -#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) -#define fmul(...) builder_->CreateFMul(__VA_ARGS__) -#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) -#define fsub(...) builder_->CreateFSub(__VA_ARGS__) -#define icmp(...) builder_->CreateICmp(__VA_ARGS__) -#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) -#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) -#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) -#define icmp_ult(...) 
builder_->CreateICmpULT(__VA_ARGS__) -#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) -#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) -#define load(...) builder_->CreateLoad(__VA_ARGS__) -#define lshr(...) builder_->CreateLShr(__VA_ARGS__) -#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) -#define min_num(...) builder_->CreateMinNum(__VA_ARGS__) -#define neg(...) builder_->CreateNeg(__VA_ARGS__) -#define phi(...) builder_->CreatePHI(__VA_ARGS__) -#define ret(...) builder_->CreateRet(__VA_ARGS__) -#define select(...) builder_->CreateSelect(__VA_ARGS__) -#define store(...) builder_->CreateStore(__VA_ARGS__) -#define sub(...) builder_->CreateSub(__VA_ARGS__) -#define shl(...) builder_->CreateShl(__VA_ARGS__) -#define udiv(...) builder_->CreateUDiv(__VA_ARGS__) -#define urem(...) builder_->CreateURem(__VA_ARGS__) -#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) -#define xor_(...) builder_->CreateXor(__VA_ARGS__) - -} // namespace triton::codegen \ No newline at end of file diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index b180ecb12..a55991475 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include "triton/codegen/selection/generator.h" #include "triton/codegen/target.h" #include "triton/codegen/analysis/axes.h" @@ -1355,9 +1356,6 @@ public: need_trans_ = k_order_ != order_[0]; can_use_ldmatrix_ = dtsize == 2 || (!need_trans_); - // std::cout << can_use_ldmatrix_ << std::endl; - // std::cout << need_trans_ << std::endl; - // we need more pointers at the fast-changing axis, if (can_use_ldmatrix_) num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]]; @@ -1365,6 +1363,9 @@ public: num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]]; num_ptr_ = std::max(num_ptr_, 2); + // special rule for i8/u8, 4 ptrs for each matrix + if (!can_use_ldmatrix_ && dtsize_ == 1) + num_ptr_ *= 4; // load_v4 stride (in num of mats) int load_stride_in_mat[2]; @@ -1445,6 +1446,46 @@ public: } return offs; // throw std::runtime_error("not implemented"); + } else if (dtsize_ == 1 && need_trans_) { + // load i8/u8 matrices with lds8 + Value *c_off_in_mat = udiv(lane, i32(4)); // + Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols + + // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); + std::vector offs(num_ptr_); + for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time + int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; + int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; + if (k_mat_arr_int > 0) // we don't need pointers for k + continue; + Value *k_mat_arr = i32(k_mat_arr_int); + Value *nk_mat_arr = i32(nk_mat_arr_int); + // physical offset (before swizzling) + Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), + mul(nk_mat_arr, i32(mat_arr_stride_))); + Value *s_mat_off = k_mat_arr; // always 0? + + for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) { + for (int elem_off = 0; elem_off < 4; ++elem_off) { + int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off; + + Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2))); + Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off)); + + // disable swizzling ... 
+ // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); + // c_mat_off_i = xor_(c_mat_off_i, phase); + + Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); + Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_))); + // To prevent out-of-bound access when the tile is too small + c_off = urem(c_off, i32(tile_shape_[order_[0]])); + s_off = urem(s_off, i32(tile_shape_[order_[1]])); + offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_))); + } + } + } + return offs; } else throw std::runtime_error("invalid smem load config"); } @@ -1461,8 +1502,10 @@ public: int ptr_idx = -1; if (can_use_ldmatrix_) ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]); - else // tf32 & trans + else if (dtsize_ == 4 && need_trans_) // tf32 & trans ptr_idx = mat_idx[order_[0]]; + else // i8 & trans + ptr_idx = mat_idx[order_[0]] * 4; auto get_ptr = [&](int idx) -> Value* { Value *ptr = nullptr; @@ -1495,9 +1538,7 @@ public: extract_val(res_v4, std::vector{1}), extract_val(res_v4, std::vector{2}), extract_val(res_v4, std::vector{3})}; - } else { - // assert(false && "should not be here"); - assert(dtsize_ == 4 && need_trans_); + } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices Value *ptr2 = get_ptr(ptr_idx+1); assert(s_mat_stride_ == 1); int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; @@ -1521,7 +1562,96 @@ public: prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3); } return {elem0, elem1, elem2, elem3}; - } + } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices + Value *ptr00 = get_ptr(ptr_idx); + Value *ptr01 = get_ptr(ptr_idx+1); + Value *ptr02 = get_ptr(ptr_idx+2); + Value *ptr03 = get_ptr(ptr_idx+3); + + Value *ptr10 = get_ptr(ptr_idx+4); + Value *ptr11 = get_ptr(ptr_idx+5); + Value *ptr12 = get_ptr(ptr_idx+6); + Value *ptr13 = get_ptr(ptr_idx+7); + + assert(s_mat_stride_ == 1); + int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; + int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; + + Value *i8v4_elems[4]; + Value *i32_elems[4]; + for (int i=0; i<4; ++i) + i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4)); + + Value *elem00, *elem01, *elem02, *elem03; + Value *elem10, *elem11, *elem12, *elem13; + Value *elem20, *elem21, *elem22, *elem23; + Value *elem30, *elem31, *elem32, *elem33; + Value *i8_elems[4*4]; + if (k_order_ == 1) { // + i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); + i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); + i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); + i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); + + assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); + + i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); + i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); + i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); + i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); + + i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); + + i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[3*4 + 2] = 
load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); + + for (int m=0; m<4; ++m) { + for (int e=0; e<4; ++e) + i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); + i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); + } + } else { // for b (k first) + i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); + i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); + i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); + i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); + + assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); + + i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); + i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); + i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); + i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); + + i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); + + i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); + i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); + + for (int m=0; m<4; ++m) { + for (int e=0; e<4; ++e) + i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); + i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); + } + } + if (k == 0 && inc == 1 && is_prefetch) { + for (int m = 0; m < 4; ++m) + for (int e = 0; e < 4; ++e) + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]); + } + return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]}; + } else + throw std::runtime_error("invalid smem load"); } int get_num_ptr() const { return num_ptr_; } @@ -1596,12 +1726,18 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: const int num_rep_n = shapes[1] / layout->shape_per_cta(1); const int num_rep_k = std::max(NK/mma_instr_k, 1); + // floating point types Type *fp32_ty = f32_ty; Type *fp16x2_ty = vec_ty(f16_ty, 2); Type *bf16x2_ty = vec_ty(bf16_ty, 2); Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty}); Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); + // integer types + Type *i8x4_ty = vec_ty(i8_ty, 4); + Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty}); + Type *i32_pack4_ty = StructType::get(*ctx_, std::vector{i32_ty, i32_ty, i32_ty, i32_ty}); + FunctionType *ldmatrix_ty = nullptr; FunctionType *mma_ty = nullptr; @@ -1630,6 +1766,16 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: smem_ptr_ty = ptr_ty(fp32_ty, 3); ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = fp32_ty; + } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) { + // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8) + mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, 
i32_ty}, false); + smem_ptr_ty = ptr_ty(i8_ty, 3); + ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = i32_ty; + // mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); + // smem_ptr_ty = ptr_ty(i8_ty, 3); + // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector{smem_ptr_ty}, false); + // phi_ty = i8x4_ty; } else throw std::runtime_error("mma16816 data type not supported"); @@ -1690,7 +1836,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: " {$4, $5, $6, $7}," " {$8, $9}," " {$10, $11, $12, $13};", - "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); + "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true); // create mma & unpack result, m, n, k are offsets in mat auto call_mma = [&](unsigned m, unsigned n, unsigned k) { @@ -1715,12 +1861,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: ir::phi_node* phiB = dynamic_cast(B); auto register_lds2 = - [&](std::map, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) { + [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { if (k < 2 && is_prefetch) { ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block)); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); } else - vals[{n, k}] = val; + vals[{mn, k}] = val; }; auto load_a = [&](int m, int k, int inc, bool is_prefetch) { @@ -1922,7 +2068,10 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { return visit_mma884(dot, A, B, D, NK); if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()? - return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); + if (dot->get_type()->get_scalar_ty()->is_fp32_ty() && + A->get_type()->get_scalar_ty()->is_fp32_ty()) + return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); + throw std::runtime_error("dot has invalid operand type"); } void generator::visit_trans_inst(ir::trans_inst* trans) { diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index bae8fe828..b381d3cb0 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ // dot(a, b, c) + d -> dot(a, b, c + d) // d + dot(a, b, c) -> dot(a, b, c + d) auto add = dynamic_cast(value); - if(add && add->get_op() == ir::binary_op_t::FAdd) { + if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) { + bool is_int_dot = add->get_op() == ir::binary_op_t::Add; ir::value *lhs = add->get_operand(0); ir::value *rhs = add->get_operand(1); ir::dot_inst *lhs_dot = dynamic_cast(lhs); @@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ ir::value *other = (dot == lhs) ? 
rhs : lhs; ir::value *acc = dot->get_operand(2); ir::splat_inst *splat = dynamic_cast(acc); - ir::constant_fp *_0 = nullptr; + ir::constant *_0 = nullptr; if(splat) - _0 = dynamic_cast(splat->get_operand(0)); - if(!(_0 && _0->get_value() == 0.0)) + _0 = dynamic_cast(splat->get_operand(0)); + if(!_0) return false; + if (auto *fp_0 = dynamic_cast(_0)) + if (fp_0->get_value() != 0.0) + return false; + if (auto *int_0 = dynamic_cast(_0)) + if (int_0->get_value() != 0) + return false; ir::value *a = dot->get_operand(0); ir::value *b = dot->get_operand(1); builder.set_insert_point(add); diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc index f94d8f368..30b2a10f2 100644 --- a/lib/codegen/transform/prefetch.cc +++ b/lib/codegen/transform/prefetch.cc @@ -33,6 +33,9 @@ void prefetch::run(ir::module &mod) { if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() || dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() || (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32() + && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) || + (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8) + && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8) && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ) ) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 910f23a1a..ada5459d6 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -46,6 +46,7 @@ #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Scalar.h" // begin AMD stuff #include "llvm/Support/FileSystem.h" @@ -121,9 +122,12 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; - // module->print(llvm::outs(), nullptr); pm.add(llvm::createVerifierPass()); + // pm.add(llvm::createDeadCodeEliminationPass()); + // pm.add(llvm::createEarlyCSEPass()); pm.run(*module); + // module->print(llvm::outs(), nullptr); + // create machine module->setTargetTriple(triple); std::string error; diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 69c76b5e5..e3168b375 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -726,7 +726,11 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask //===----------------------------------------------------------------------===// ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) { - ir::value *_0 = builder->get_float32(0); + ir::value *_0 = nullptr; + if (lhs->get_type()->is_int_or_tileint_ty()) + _0 = builder->get_int32(0); + else + _0 = builder->get_float32(0); unsigned M = lhs->get_type()->get_block_shapes()[0]; unsigned N = rhs->get_type()->get_block_shapes()[1]; _0 = builder->create_splat(_0, {M, N}); diff --git a/python/setup.py b/python/setup.py index db22c14af..86e3e5160 100644 --- a/python/setup.py +++ b/python/setup.py @@ -77,7 +77,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - # self.debug = True + self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4d6e32aa6..3e35700f8 100644 --- 
a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # --------------- -@pytest.mark.parametrize("epilogue, allow_tf32", - [(epilogue, allow_tf32) +@pytest.mark.parametrize("epilogue, allow_tf32, dtype", + [(epilogue, allow_tf32, dtype) for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] - for allow_tf32 in [True, False]]) -def test_dot(epilogue, allow_tf32, device='cuda'): + for allow_tf32 in [True, False] + for dtype in ['float32', 'int8'] + if not (allow_tf32 and (dtype == 'int8'))]) +def test_dot(epilogue, allow_tf32, dtype, device='cuda'): + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80: + if dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 80") + elif dtype == 'float32' and allow_tf32: + pytest.skip("Only test tf32 on devices with sm >= 80") + # triton kernel @triton.jit def kernel(X, stride_xm, stride_xk, @@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'): # input M, N, K = 64, 64, 32 rs = RandomState(17) - x = numpy_random((M, K), dtype_str='float32', rs=rs) - y = numpy_random((K, N), dtype_str='float32', rs=rs) + x = numpy_random((M, K), dtype_str=dtype, rs=rs) + y = numpy_random((K, N), dtype_str=dtype, rs=rs) if allow_tf32: - cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) - if cc < 80: - pytest.skip("Only test tf32 on devices with sm >= 80") x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') x_tri = to_triton(x, device=device) y_tri = to_triton(y, device=device) # triton result - z = numpy_random((M, N), dtype_str='float32', rs=rs) + z = numpy_random((M, N), dtype_str=dtype, rs=rs) z_tri = to_triton(z, device=device) if epilogue == 'trans': z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) @@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'): assert 'st.global.v4' in ptx if allow_tf32: assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx - else: + elif dtype == 'float32': assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx + elif dtype == 'int8': + assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx def test_dot_without_load(): From e68d6a7776f5e7549da777c60b2cfa1c070646e2 Mon Sep 17 00:00:00 2001 From: daadaada Date: Fri, 28 Jan 2022 01:59:54 +0800 Subject: [PATCH 055/215] [BACKEND] Making the warp-level tile "more square" to increase data-reuse for tl.dot. 
(#442) * Increase smem data-reuse for some layouts * tweak * Keep the original tiling logic for sm < 80 Co-authored-by: Philippe Tillet --- lib/codegen/analysis/layout.cc | 38 ++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 2206f5b6a..587234863 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -217,16 +217,36 @@ mma_layout::mma_layout(size_t num_warps, order_ = {0, 1}; /* warps per tile */ - // try to make things as square as possible to maximize data re-use wpt_ = {1, 1, 1}; - std::vector wpt_nm1; - do{ - wpt_nm1 = wpt_; - if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]); - if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); - }while(wpt_nm1 != wpt_); + // try to make warp-level tiles as square as possible to maximize data re-use + if (tgt->as_nvidia()->sm() < 80) { + std::vector wpt_nm1; + do{ + wpt_nm1 = wpt_; + if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) + wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]); + if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) + wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); + }while(wpt_nm1 != wpt_); + } else { + bool changed = false; + do { + changed = false; + if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) + break; + if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { + if (wpt_[0] < shape_[0] / spw_[0]) { + wpt_[0] *= 2; + changed = true; + } + } else { + if (wpt_[1] < shape_[1] / (spw_[1]*2)) { + wpt_[1] *= 2; + changed = true; + } + } + } while (changed); + } /* shape per block */ shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; From bd52e530a027f896ac8f400df3a8894f01a2e265 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 28 Jan 2022 21:40:30 -0800 Subject: [PATCH 056/215] [OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445) --- python/triton/ops/blocksparse/matmul.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 48efe7ea3..ebcea8e1e 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -311,7 +311,10 @@ def dsd_lut(layout, block, step, trans, device): header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() # create increments incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() - incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + # pad by a factor 2*MAX_NUM_STAGES + # to accomodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) # create lut lut = torch.cat((header, incs)) lut = lut.type(torch.int32).to(device) From bef76b142a859ad17ae3ff540d52387f92e93fe3 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 29 Jan 2022 18:29:29 -0800 Subject: [PATCH 057/215] [BACKEND] float division is now approximate by default (#446) --- include/triton/ir/dispatch.h | 1 + include/triton/ir/instructions.h | 7 +++++++ include/triton/ir/module.h | 3 +++ lib/codegen/selection/generator.cc | 7 +++++++ lib/ir/dispatch.cc | 12 ++++++++++++ lib/ir/instructions.cc | 2 +- python/src/triton.cc | 4 ++++ python/triton/code_gen.py | 2 ++ python/triton/language/core.py | 9 +++++++-- 9 files changed, 44 insertions(+), 3 deletions(-) diff --git a/include/triton/ir/dispatch.h 
b/include/triton/ir/dispatch.h index d8293d231..7dc8ab0ed 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -34,6 +34,7 @@ struct dispatch{ static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder); static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index ca1416f48..5af077e8f 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -117,6 +117,7 @@ private: //===----------------------------------------------------------------------===// // binary_operator classes //===----------------------------------------------------------------------===// + class binary_operator: public instruction { public: typedef binary_op_t op_t; @@ -145,6 +146,10 @@ public: bool is_shl() const; bool is_shr() const; + // Approx + void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; } + bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; } + // Wraps void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; } @@ -163,6 +168,8 @@ public: binary_op_t op_; bool has_no_unsigned_wrap_; bool has_no_signed_wrap_; + + bool fdiv_ieee_rnd_; }; diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index b350e3cc9..30881fd49 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -66,7 +66,10 @@ public: void set_continue_fn(std::function fn); // Getters const std::map& get_values() { return values_; } + const std::map& get_types() { return types_; } void set_values(const std::map& values) { values_ = values; } + void set_types(const std::map& types) { types_ = types; } + value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name); void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a55991475..986e8212e 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -319,6 +319,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) { vals_[x][idx] = add(lhs, rhs); else if(op == ll::Mul) vals_[x][idx] = mul(lhs, rhs); + else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && + x->get_type()->get_scalar_ty()->is_fp32_ty()){ + InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), + " div.full.f32 $0, $1, $2;", "=r,r,r", false); + vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); + + } else vals_[x][idx] = bin_op(op, lhs, rhs); } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index e3168b375..5a2d33a98 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -204,6 +204,18 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b throw_unreachable("floordiv"); } +ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){ + ir::type 
*input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty()) + throw semantic_error("both operands of fdiv must have floating point scalar type"); + binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES); + ir::value* ret = builder->create_fdiv(input, other); + if(ir::binary_operator* binop = dynamic_cast(ret)) + binop->set_fdiv_ieee_rounding(ieee_rounding->get_value()); + return ret; +} + ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *scalar_ty = input->get_type()->get_scalar_ty(); diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 6e416a43e..ac5a17289 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -134,7 +134,7 @@ bool binary_operator::is_int_add_sub() const { binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) - : instruction(ty, INST_BINOP, 2, name, next), op_(op){ + : instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){ set_operand(0, lhs); set_operand(1, rhs); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 8ceb14200..77edf791f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -556,6 +556,7 @@ void init_triton_frontend(py::module &&m) { m.def("mul", &ir::dispatch::mul, ret::reference); m.def("truediv", &ir::dispatch::truediv, ret::reference); m.def("floordiv", &ir::dispatch::floordiv, ret::reference); + m.def("fdiv", &ir::dispatch::fdiv, ret::reference); m.def("mod", &ir::dispatch::mod, ret::reference); m.def("and_", &ir::dispatch::and_, ret::reference); m.def("or_", &ir::dispatch::or_, ret::reference); @@ -691,6 +692,7 @@ void init_triton_ir(py::module &&m) { .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) + .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) .def_property_readonly("context", &ir::type::get_context, ret::reference); @@ -713,6 +715,8 @@ void init_triton_ir(py::module &&m) { .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) .def("get_values", &ir::module::get_values, ret::reference) .def("set_values", &ir::module::set_values) + .def("get_types", &ir::module::get_types, ret::reference) + .def("set_types", &ir::module::set_types) .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 347635e32..48057b770 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -986,12 +986,14 @@ class JITFunction: gscope = generator.gscope.copy() lscope = generator.lscope.copy() values = generator.module.get_values().copy() + types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope generator.module.set_values(values) + 
generator.module.set_types(types) return ret except Exception as e: node = generator.last_node diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df3b1f4cf..425e12e01 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -699,6 +699,11 @@ def umulhi(x, y, _builder=None): return frontend.umulhi(x, y, _builder) +@builtin +def fdiv(x, y, ieee_rounding=False, _builder=None): + return frontend.fdiv(x, y, ieee_rounding, _builder) + + def _add_math_1arg_docstr(name): def _decorator(func): @@ -869,11 +874,11 @@ def sigmoid(x): @triton.jit @_add_math_1arg_docstr("softmax") -def softmax(x): +def softmax(x, ieee_rounding=False): z = x - triton.language.max(x, 0) num = triton.language.exp(z) den = triton.language.sum(num, 0) - return num / den + return fdiv(num, den, ieee_rounding) @triton.jit From 807d8a194528e8a67975a49f9a38fc4f7420bcd5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 30 Jan 2022 20:21:20 -0800 Subject: [PATCH 058/215] [ALL] Merge master (#447) --- .gitignore | 3 + .gitmodules | 3 + CMakeLists.txt | 51 +++++++- deps/dlfcn-win32 | 1 + include/triton/tools/sys/exec.hpp | 8 ++ include/triton/tools/sys/getenv.hpp | 11 +- lib/codegen/selection/generator.cc | 20 +-- lib/driver/dispatch.cc | 9 ++ lib/driver/llvm.cc | 14 +- python/setup.py | 4 +- python/src/triton.cc | 1 + python/test/unit/runtime/test_comm.py | 4 +- python/triton/code_gen.py | 23 ++-- python/triton/language/core.py | 177 +++++++++++++------------- 14 files changed, 199 insertions(+), 130 deletions(-) create mode 100644 .gitmodules create mode 160000 deps/dlfcn-win32 diff --git a/.gitignore b/.gitignore index c10863ae9..b32df68cc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ +build/ + __pycache__ .pytest_cache python/build/ python/triton.egg-info/ +python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..2754cffc4 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "deps/dlfcn-win32"] + path = deps/dlfcn-win32 + url = https://github.com/dlfcn-win32/dlfcn-win32.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 8fb73e678..f44c35aa7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,8 @@ cmake_minimum_required(VERSION 3.6) include(ExternalProject) +set(CMAKE_CXX_STANDARD 17) + if(NOT TRITON_LLVM_BUILD_DIR) set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR}) endif() @@ -8,7 +10,9 @@ endif() project(triton) include(CTest) -list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +if(NOT WIN32) + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +endif() # Options option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON) @@ -20,10 +24,19 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -find_library(TERMINFO_LIBRARY tinfo) +if(NOT WIN32) + find_library(TERMINFO_LIBRARY tinfo) +endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +if(WIN32) + SET(BUILD_SHARED_LIBS OFF) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src) + add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32) +endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") @@ -31,7 +44,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17") # LLVM ########## if("${LLVM_LIBRARY_DIR}" STREQUAL "") - find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu") + if(WIN32) + find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu) + + 
include_directories(${LLVM_INCLUDE_DIRS}) + separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) + add_definitions(${LLVM_DEFINITIONS_LIST}) + + llvm_map_components_to_libnames(LLVM_LIBRARIES support core + NVPTXInfo nvptxcodegen + AMDGPUInfo AMDGPUcodegen + ) + else() + find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu") + endif() message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") if(APPLE) set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14") @@ -108,12 +134,25 @@ endif() # Triton file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) -add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) +if (WIN32 AND BUILD_PYTHON_MODULE) + find_package(Python3 REQUIRED COMPONENTS Development) + Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) + set_target_properties(triton PROPERTIES SUFFIX ".pyd") + set_target_properties(triton PROPERTIES PREFIX "lib") +else() + add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) +endif() + target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) -target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY}) + +if(WIN32) + target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32 +else() + target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY}) +endif() -if(BUILD_PYTHON_MODULE) +if(BUILD_PYTHON_MODULE AND NOT WIN32) set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") # Check if the platform is MacOS if(APPLE) diff --git a/deps/dlfcn-win32 b/deps/dlfcn-win32 new file mode 160000 index 000000000..522c301ec --- /dev/null +++ b/deps/dlfcn-win32 @@ -0,0 +1 @@ +Subproject commit 522c301ec366e9b42205ae21617780d37cc0e9f0 diff --git a/include/triton/tools/sys/exec.hpp b/include/triton/tools/sys/exec.hpp index 243f0f482..5b664553e 100644 --- a/include/triton/tools/sys/exec.hpp +++ b/include/triton/tools/sys/exec.hpp @@ -13,6 +13,14 @@ namespace tools { +#ifdef _WIN32 +#define popen _popen +#define pclose _pclose +#endif + +#ifndef WEXITSTATUS +#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255) +#endif int exec(const std::string& cmd, std::string& result) { char buffer[128]; diff --git a/include/triton/tools/sys/getenv.hpp b/include/triton/tools/sys/getenv.hpp index 0319d8868..755a84a66 100755 --- a/include/triton/tools/sys/getenv.hpp +++ b/include/triton/tools/sys/getenv.hpp @@ -33,19 +33,10 @@ namespace tools inline std::string getenv(const char * name) { - #ifdef _MSC_VER - char* cache_path = 0; - std::size_t sz = 0; - _dupenv_s(&cache_path, &sz, name); - #else - const char * cstr = std::getenv(name); - #endif + const char * cstr = std::getenv(name); if(!cstr) return ""; std::string result(cstr); - #ifdef _MSC_VER - free(cache_path); - #endif return result; } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 986e8212e..26b6b342a 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -449,18 +449,18 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" "}", "=r,=r,r", false); Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); - packed_in = insert_elt(packed_in, in0, (int)0); - packed_in = insert_elt(packed_in, in1, (int)1); - packed_in = insert_elt(packed_in, in2, (int)2); - packed_in = insert_elt(packed_in, in3, (int)3); + packed_in = insert_elt(packed_in, in0, (uint64_t)0); + packed_in = insert_elt(packed_in, in1, (uint64_t)1); + packed_in = insert_elt(packed_in, in2, (uint64_t)2); + packed_in = insert_elt(packed_in, in3, (uint64_t)3); Value *in = bit_cast(packed_in, i32_ty); Value 
*ret = call(ptx, {in}); Value *packed_ret0 = extract_val(ret, {0}); Value *packed_ret1 = extract_val(ret, {1}); - Value *ret0 = extract_elt(packed_ret0, (int)0); - Value *ret1 = extract_elt(packed_ret0, (int)1); - Value *ret2 = extract_elt(packed_ret1, (int)0); - Value *ret3 = extract_elt(packed_ret1, (int)1); + Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); + Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); + Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); + Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); return std::make_tuple(ret0, ret1, ret2, ret3); } @@ -717,11 +717,11 @@ void generator::visit_load_inst(ir::load_inst* x){ // --- // finally call inline ASM // --- - InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); + InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); - Value *_ret = call(_asm, args); + Value *_ret = call(inlineAsm, args); // --- // extract and store return values // --- diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 4059ac235..9e2aca432 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -91,9 +91,13 @@ void* dispatch::fname ## _; bool dispatch::cuinit(){ if(cuda_==nullptr){ + #ifdef _WIN32 + cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY); + #else cuda_ = dlopen("libcuda.so", RTLD_LAZY); if(!cuda_) cuda_ = dlopen("libcuda.so.1", RTLD_LAZY); + #endif if(!cuda_) throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH."); } @@ -176,8 +180,13 @@ CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent) * NVML * ------------------- */ bool dispatch::nvmlinit(){ + #ifdef _WIN32 + if(nvml_==nullptr) + nvml_ = dlopen("nvml.dll", RTLD_LAZY); + #else if(nvml_==nullptr) nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY); + #endif nvmlReturn_t (*fptr)(); nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2"); *reinterpret_cast(&fptr) = nvmlInit_v2_; diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index ada5459d6..f25b763ca 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -20,7 +20,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include -#include +#if __has_include() + #include +#endif #include #include #include "triton/driver/llvm.h" @@ -185,8 +187,10 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { // compile ptx with ptxas char _fsrc[L_tmpnam]; char _flog[L_tmpnam]; - std::string fsrc = std::tmpnam(_fsrc); - std::string flog = std::tmpnam(_flog); + std::tmpnam(_fsrc); + std::tmpnam(_flog); + std::string fsrc = _fsrc; + std::string flog = _flog; std::string fbin = fsrc + ".o"; const char* _fbin = fbin.c_str(); std::ofstream ofs(fsrc); @@ -367,8 +371,8 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) { hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer, hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose}; - unsigned int errbufsize = 8192; - unsigned int logbufsize = 8192; + const unsigned int errbufsize = 8192; + const unsigned int logbufsize = 8192; char _err[errbufsize]; char _log[logbufsize]; void* optval[] = {(void*)(uintptr_t)errbufsize, diff --git a/python/setup.py b/python/setup.py index 86e3e5160..6a04a4e42 100644 --- a/python/setup.py +++ b/python/setup.py @@ -23,6 +23,8 @@ def get_llvm(): paths = [p for p in paths if p is not None] if paths: return '', '' + if platform.system() == "Windows": + return '', '' # download if nothing is installed name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04' dir = '/tmp' @@ -104,7 +106,7 @@ class CMakeBuild(build_ext): build_args = ["--config", cfg] if platform.system() == "Windows": - cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)] + cmake_args += ["-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)] if sys.maxsize > 2**32: cmake_args += ["-A", "x64"] build_args += ["--", "/m"] diff --git a/python/src/triton.cc b/python/src/triton.cc index 77edf791f..e9c5e637c 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -15,6 +15,7 @@ #include #include "Python.h" #include +#include #include #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py index 6d0658f3b..ae3fb69d7 100644 --- a/python/test/unit/runtime/test_comm.py +++ b/python/test/unit/runtime/test_comm.py @@ -25,13 +25,13 @@ def get_p2p_matrix(): def get_p2p_devices(): matrix = get_p2p_matrix() idx = np.where(matrix == "OK") - return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}" + return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] def get_non_p2p_devices(): matrix = get_p2p_matrix() idx = np.where(matrix == "NS") - return f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}" + return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] p2p_devices = get_p2p_devices() diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 48057b770..dc2b375b8 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -358,9 +358,6 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) - def visit_Str(self, node): - return ast.literal_eval(node) - def visit_Subscript(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) @@ -441,9 +438,6 @@ class CodeGenerator(ast.NodeVisitor): def visit_Index(self, node): return self.visit(node.value) - def visit_NameConstant(self, node): - return node.value - def visit_keyword(self, node): return {node.arg: self.visit(node.value)} @@ -460,10 +454,23 @@ class CodeGenerator(ast.NodeVisitor): if 
hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) + if fn in self.builtins.values(): + args = [arg.value if isinstance(arg, triton.language.constexpr) else arg + for arg in args] return fn(*args, **kws) - def visit_Num(self, node): - return triton.language.constexpr(node.n) + def visit_Constant(self, node): + return triton.language.constexpr(node.value) + + if sys.version_info < (3, 8): + def visit_NameConstant(self, node): + return triton.language.constexpr(node.value) + + def visit_Num(self, node): + return triton.language.constexpr(node.n) + + def visit_Str(self, node): + return triton.language.constexpr(ast.literal_eval(node)) def visit_Attribute(self, node): lhs = self.visit(node.value) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 425e12e01..f4188f1c7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -130,6 +130,94 @@ float64 = dtype(ir.type.get_fp64) # pointer types pi32_t = pointer_dtype(int32) +# ----------------------- +# constexpr +# ----------------------- + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + # + def __add__(self, other): + return self.value + other.value + + def __radd__(self, other): + return other.value + self.value + + def __sub__(self, other): + return self.value - other.value + + def __rsub__(self, other): + return other.value - self.value + + def __mul__(self, other): + return self.value * other.value + + def __rmul__(self, other): + return other.value * self.value + + def __truediv__(self, other): + return self.value / other.value + + def __rtruediv__(self, other): + return other.value / self.value + + def __floordiv__(self, other): + return self.value // other.value + + def __rfloordiv__(self, other): + return other.value // self.value + + # + + def __gt__(self, other): + return self.value > other.value + + def __rgt__(self, other): + return other.value > self.value + + def __ge__(self, other): + return self.value >= other.value + + def __rge__(self, other): + return other.value >= self.value + + def __lt__(self, other): + return self.value < other.value + + def __rlt__(self, other): + return other.value < self.value + + def __le__(self, other): + return self.value <= other.value + + def __rle__(self, other): + return other.value <= self.value + + def __eq__(self, other): + return self.value == other.value + + def __ne__(self, other): + return self.value != other.value + + def __bool__(self): + return bool(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + class block: @staticmethod @@ -296,7 +384,7 @@ class block: dst_shape = [] curr = 0 for sl in slices: - if sl is None: + if isinstance(sl, constexpr) and sl.value is None: dst_shape.append(1) elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) @@ -312,93 +400,6 @@ class block: return frontend.cast(self, dtype, _builder) -# ----------------------- -# constexpr -# ----------------------- - -class constexpr: - """ - This class is used to store a value that is known at compile-time. 
- """ - - def __init__(self, value): - if isinstance(value, constexpr): - self.value = value.value - else: - self.value = value - - def __repr__(self) -> str: - return f"constexpr[{self.value}]" - - # - def __add__(self, other): - return self.value + other.value - - def __radd__(self, other): - return other.value + self.value - - def __sub__(self, other): - return self.value - other.value - - def __rsub__(self, other): - return other.value - self.value - - def __mul__(self, other): - return self.value * other.value - - def __rmul__(self, other): - return other.value * self.value - - def __truediv__(self, other): - return self.value / other.value - - def __rtruediv__(self, other): - return other.value / self.value - - def __floordiv__(self, other): - return self.value // other.value - - def __rfloordiv__(self, other): - return other.value // self.value - - # - - def __gt__(self, other): - return self.value > other.value - - def __rgt__(self, other): - return other.value > self.value - - def __ge__(self, other): - return self.value >= other.value - - def __rge__(self, other): - return other.value >= self.value - - def __lt__(self, other): - return self.value < other.value - - def __rlt__(self, other): - return other.value < self.value - - def __le__(self, other): - return self.value <= other.value - - def __rle__(self, other): - return other.value <= self.value - - def __eq__(self, other): - return self.value == other.value - - def __ne__(self, other): - return self.value != other.value - - def __bool__(self): - return bool(self.value) - - def __call__(self, *args, **kwds): - return self.value(*args, **kwds) - # ----------------------- # SPMD Programming Model # ----------------------- From b0d6e2f3228e8ccf3e64f2a3c104b849efdd6989 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 30 Jan 2022 20:27:44 -0800 Subject: [PATCH 059/215] [STYLE] run autopep --- python/triton/code_gen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 6e932859e..dc2b375b8 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -472,7 +472,6 @@ class CodeGenerator(ast.NodeVisitor): def visit_Str(self, node): return triton.language.constexpr(ast.literal_eval(node)) - def visit_Attribute(self, node): lhs = self.visit(node.value) return getattr(lhs, node.attr) From 137bb67fad31568463b1be85ee12706a35650855 Mon Sep 17 00:00:00 2001 From: TC <93944281+tomconerlyanth@users.noreply.github.com> Date: Wed, 2 Feb 2022 20:42:09 -0800 Subject: [PATCH 060/215] [LANG] Add fp16 to fp8 conversion (#444) --- lib/codegen/selection/generator.cc | 79 +++++++++++++++++++----- python/test/unit/language/test_core.py | 84 ++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 15 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 26b6b342a..d2ebce1c6 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -414,13 +414,13 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) { std::tuple generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){ - auto cvt = [this](Value *v){ - if(ConstantFP* ci = dyn_cast(v)) - if(ci->getValue().convertToFloat() == 0) - return builder_->getInt8(0); - throw std::runtime_error("unsupported cast"); - }; - return std::make_tuple(cvt(in0), cvt(in1), cvt(in2), cvt(in3)); + in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty); + in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty); + in2 = cast(llvm::Instruction::FPTrunc, 
in2, f16_ty); + in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty); + Value *ret0, *ret1, *ret2, *ret3; + std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3); + return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){ @@ -439,14 +439,14 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), "{" ".reg .b32 a<2>, b<2>; \n\t" - "prmt.b32 a0, 0, $2, 0x5140; \n\t" - "prmt.b32 a1, 0, $2, 0x7362; \n\t" - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" - "shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion - "shr.b32 b1, b1, 1; \n\t" - "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign - "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" + "prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0 + "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) + "shr.b32 b0, b0, 1; \n\t" // b0 <<= 1 (shift into fp16 poistion) + "shr.b32 b1, b1, 1; \n\t" // b1 <<= 1 (shift into fp16 position) + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 | a0) (restore sign) + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 | a1) (restore sign) "}", "=r,=r,r", false); Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); packed_in = insert_elt(packed_in, in0, (uint64_t)0); @@ -464,6 +464,51 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 return std::make_tuple(ret0, ret1, ret2, ret3); } +std::tuple generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { + /* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa) + * fp8 bit representation is seeeemmm + * The 4 fp8 exponent bits are the low order 4 exponent bits in fp16. + * The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16. + * Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous. + * We want to round to nearest fp8 value. To do that add 1 to 4th mantissa bit in fp16 (that's + * one more than the number of mantissa bits in fp8). + * fp8 = (fp16 & 0x8000) | (((f16 << 1) + 0x0080) & 0x7fff) + * + * We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the + * other. To avoid this we zero out the most significant exponent bit. If that bit is set then + * the value isn't representable in float8 anyway so we assume it's never set (and give garbage + * output if it is). If we were willing to assume the most significant exponent was never set + * we could save the first two lop3.b32 instructions below. 
+ */ + InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), + "{" + ".reg .b32 a<2>, b<2>; \n\t" + "shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1 + "shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1 + "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = (a0 & 0x7fff7fff) + "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = (a1 & 0x7fff7fff) + "add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080 + "add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080 + "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0 + "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1 + "prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1=0x0123 sets output to 0xac02 + "}", "=r,r,r", false); + Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2)); + Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2)); + packed_in0 = insert_elt(packed_in0, in0, (int)0); + packed_in0 = insert_elt(packed_in0, in1, (int)1); + packed_in1 = insert_elt(packed_in1, in2, (int)0); + packed_in1 = insert_elt(packed_in1, in3, (int)1); + Value *in_arg0 = bit_cast(packed_in0, i32_ty); + Value *in_arg1 = bit_cast(packed_in1, i32_ty); + Value *ret = call(ptx, {in_arg0, in_arg1}); + Value *ret0 = extract_elt(ret, (int)0); + Value *ret1 = extract_elt(ret, (int)1); + Value *ret2 = extract_elt(ret, (int)2); + Value *ret3 = extract_elt(ret, (int)3); + return std::make_tuple(ret0, ret1, ret2, ret3); +} + Value* generator::bf16_to_fp32(Value *in0){ if (tgt_->as_nvidia()->sm() >= 80) { InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), @@ -508,8 +553,12 @@ void generator::visit_cast_inst(ir::cast_inst* x) { auto cvt = [&](Value* a, Value* b, Value* c, Value* d){ if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty()) return fp32x4_to_fp8x4(a, b, c, d); + if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty()) + return fp16x4_to_fp8x4(a, b, c, d); if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty()) return fp8x4_to_fp16x4(a, b, c, d); + if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty()) + return fp8x4_to_fp32x4(a, b, c, d); throw std::runtime_error("unsupported conversion"); }; for(size_t i = 0; i < x_idxs.size(); i+=4){ diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3e35700f8..a49b47585 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -565,6 +565,90 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): z_ref = x.astype(getattr(np, dtype_z)) assert to_numpy(z_tri) == z_ref + +def test_f8_f16_roundtrip(): + """Tests that converting an f8 to f16 and back to f8 doesn't change its value""" + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda') + f8 = triton.reinterpret(f8_tensor, tl.float8) + n_elements = f8_tensor.numel() + f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024) + + f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) + f8_output = triton.reinterpret(f8_output_tensor, tl.float8) + print(f16.dtype, f8_output.dtype) + 
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) + + assert torch.all(f8_tensor == f8_output_tensor) + + +def test_f16_to_f8_rounding(): + """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute + error is the minimum over all float8. + + Or the same explanation a bit mathier: + for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|""" + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view + f16_input_np = ( + np.array( + range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16, + ) + .view(np.float16) + ) + f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda') + n_elements = f16_input.numel() + f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8) + f8_output = triton.reinterpret(f8_output_tensor, tl.float8) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024) + + f16_output = torch.empty_like(f16_input, dtype=torch.float16) + copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024) + + abs_error = torch.abs(f16_input - f16_output) + + all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda') + all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8) + all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16) + copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024) + + all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[ + torch.isfinite(all_f8_vals_in_f16) + ] + + min_error = torch.min( + torch.abs( + f16_input.reshape((-1, 1)) + - all_finite_f8_vals_in_f16.reshape((1, -1)) + ), + dim=1, + )[0] + # 1.9375 is float8 max + mismatch = torch.logical_and( + abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375) + ) + assert torch.all( + torch.logical_not(mismatch) + ), f"{f16_input[mismatch]=} {f16_output[mismatch]=} {abs_error[mismatch]=} {min_error[mismatch]=}" + + # --------------- # test reduce # --------------- From 69ff52ea1f98336e89f6cf27694196309e709aa9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 5 Feb 2022 21:37:23 -0800 Subject: [PATCH 061/215] [CODEGEN] removed buggy (and mostly useless) optimization in peephole pass (#449) --- lib/codegen/transform/peephole.cc | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index b381d3cb0..e30ab9b35 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -227,22 +227,22 @@ bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){ ir::instruction* op = dynamic_cast(cvt->get_operand(0)); if(!op) return false; - // convert(elementwise(x, y)) = elementwise(convert(x), convert(y)) - if(op->get_id() == ir::INST_BINOP){ - for(size_t i = 0; i < op->get_num_operands(); i++){ - ir::value* arg_i = op->get_operand(i); - builder.set_insert_point(op); - // create new layout transform - ir::instruction* new_arg_i = cvt->clone(); - layouts_->copy(new_arg_i, op); - builder.insert(new_arg_i); - // set the 
right args - new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); - op->replace_uses_of_with(arg_i, new_arg_i); - } - cvt->replace_all_uses_with(op); - return true; - } +// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y)) +// if(op->get_id() == ir::INST_BINOP){ +// for(size_t i = 0; i < op->get_num_operands(); i++){ +// ir::value* arg_i = op->get_operand(i); +// builder.set_insert_point(op); +// // create new layout transform +// ir::instruction* new_arg_i = cvt->clone(); +// layouts_->copy(new_arg_i, op); +// builder.insert(new_arg_i); +// // set the right args +// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); +// op->replace_uses_of_with(arg_i, new_arg_i); +// } +// cvt->replace_all_uses_with(op); +// return true; +// } auto cvt_op = dynamic_cast(op); if(!cvt_op) return false; From 5a8a544d101e632500f05c2c9f87d56a0d376adc Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 6 Feb 2022 18:00:45 -0800 Subject: [PATCH 062/215] [OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450) * dds layout now internally re-uses dsd code path for increased code * at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks. * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros * blocksparse softmax now no longer modifies any data in-place * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention. * unit tests now test backward pass --- .../test/unit/operators/test_blocksparse.py | 161 ++++---- python/triton/ops/blocksparse/matmul.py | 141 ++----- python/triton/ops/blocksparse/softmax.py | 357 +++++++++--------- python/triton/testing.py | 13 + 4 files changed, 311 insertions(+), 361 deletions(-) diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index ed569c04d..9e0c72de9 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -10,77 +10,108 @@ import triton @pytest.mark.parametrize("BLOCK", [16, 32, 64]) @pytest.mark.parametrize("DTYPE", [torch.float16]) def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): - # set seed - torch.random.manual_seed(0) + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK) # create inputs - a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda") - b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda") + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) shape = { "sdd": (M, N), - "dsd": (a.shape[2], a.shape[3]), - "dds": (b.shape[2], b.shape[3]), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), }[MODE] layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # create data + a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1) + b_ref, b_tri = 
triton.testing.make_pair(b_shape, alpha=.1) + dc_ref, dc_tri = triton.testing.make_pair(c_shape) + # compute [torch] + dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + a_ref.retain_grad() + b_ref.retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, + b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad # triton result + dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + a_tri.retain_grad() + b_tri.retain_grad() op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") - ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a - rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b - rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest) - # torch result - ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a - tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b - ta = ta.transpose(2, 3) if TRANS_A else ta - tb = tb.transpose(2, 3) if TRANS_B else tb - tc = torch.matmul(ta, tb) - tc = triton.testing.mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc - tc = triton.testing.sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc + c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest) + triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest) + da_tri = a_tri.grad + db_tri = b_tri.grad # compare - triton.testing.assert_almost_equal(rc, tc) + triton.testing.assert_almost_equal(c_ref, c_tri) + triton.testing.assert_almost_equal(da_ref, da_tri) + triton.testing.assert_almost_equal(db_ref, db_tri) -@pytest.mark.parametrize("BLOCK", [16, 32, 64]) -@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792]) -@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) -def test_softmax(BLOCK, WIDTH, DTYPE): - is_causal = True +configs = [ + (16, 256), + (32, 576), + (64, 1871), + (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): # set seed torch.random.manual_seed(0) - Z, H, M, N = 1, 1, WIDTH, WIDTH - scale = 0.4 - # create inputs - layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) - x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") - at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda") + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout # make sure each row has at least one non-zero element - torch.diagonal(layout)[:] = 1 - torch.diagonal(at_mask)[:] = 1 - kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda") - kp_mask[:] = 0 - kp_mask[kp_mask == 1.0] = float("-inf") - # triton result - op = triton.ops.blocksparse.softmax(layout, BLOCK) - tx = triton.testing.sparsify_tensor(x, layout, BLOCK) - ty = op( - tx, - scale=scale, - key_padding_mask=kp_mask, - key_padding_mask_mode="add", - attn_mask=at_mask.to(DTYPE), - attn_mask_mode="mul", - is_causal=is_causal, - ) - # torch result - rx = triton.testing.mask_tensor(x, layout, BLOCK, 
value=float("-inf")) - # broadcast at_mask to the same shape as rx + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = triton.testing.make_pair(a_shape) + dout_ref, dout_tri = triton.testing.make_pair(a_shape) + # compute [torch] + a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device="cuda") if is_causal: at_mask = torch.tril(at_mask) - M = at_mask[None, None, :, :] + torch.zeros_like(rx) - rx[M == 0] = float("-inf") - # rx += kp_mask[:, None, None, :] - ry = torch.softmax(rx * scale, -1) - ry = triton.testing.sparsify_tensor(ry, layout, BLOCK) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK) + da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad # compare - triton.testing.assert_almost_equal(ry, ty) + triton.testing.assert_almost_equal(out_tri, out_ref) + triton.testing.assert_almost_equal(da_tri, da_ref) @pytest.mark.parametrize("block", [16, 32, 64]) @@ -99,14 +130,6 @@ def test_attention_fwd_bwd( qkvs = [ torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) ] - attn_mask = torch.tril( - torch.ones( - [n_ctx, n_ctx], - device="cuda", - dtype=dtype, - ), - diagonal=0, - ) # Triton: n_blocks = n_ctx // block @@ -115,7 +138,7 @@ def test_attention_fwd_bwd( query.retain_grad() key.retain_grad() value.retain_grad() - attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale) + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) # ad hoc loss loss = (attn_out ** 2).mean() loss.backward() @@ -123,6 +146,8 @@ def test_attention_fwd_bwd( # Torch version: torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) torch_q.retain_grad() torch_k.retain_grad() @@ -147,7 +172,6 @@ def test_attention_fwd_bwd( def triton_attention( layout, block: int, - attn_mask: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -155,12 +179,9 @@ def triton_attention( ): sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) - sparse_softmax = triton.ops.blocksparse.softmax( - layout, - block, - ) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) w = sparse_dot_sdd_nt(query, key) - w = sparse_softmax(w, scale=scale, attn_mask=attn_mask, attn_mask_mode="mul") + w = sparse_softmax(w, scale=scale, is_causal=True) a = sparse_dot_dsd_nn(w, value) return a diff --git 
a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index ebcea8e1e..0fa1a5878 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -242,8 +242,31 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ sizes = torch.sum(layout, 2 if trans else 1) - head_id, col_id = sizes.nonzero(as_tuple=True) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) sizes = sizes.flatten() segments = sizes * step # pointer increments @@ -258,13 +281,6 @@ def dsd_lut(layout, block, step, trans, device): # ------------------------------- # dense input pointer increments # ------------------------------- - # given a list of the indices for the first element of each non-zero block. - # For example, for the indices - # [32, 80, 128, 256, 288] - # we would generate the increments - # [32, 48, 48, 128, 32] - # ^ - # index of first element # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) # that is smaller than the block size, so we need to do a bit of extra work # to handle this case @@ -324,116 +340,11 @@ def dsd_lut(layout, block, step, trans, device): # ----------------------------- # Dense = Dense x Sparse (DDS) # ----------------------------- - - -@triton.jit -def _dds_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ka, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_mc, stride_nc, - DS0, DS1, lut, - TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr, -): - # ------------ # - # - Prologue - # - # ------------ # - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) - pid_z = tl.program_id(2) - header = lut + pid_n * 4 - offset = tl.load(header + 0) - AS1 = tl.load(header + 1) - column = tl.load(header + 2) - off_h = tl.load(header + 3) - pinc = lut + offset - # initialize pointers to A (dense) - offs_am = pid_m * TILE_M + tl.arange(0, TILE_M) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M) - start_ak = tl.load(pinc) - start_ak = tl.multiple_of(start_ak, 8) - offs_ak = start_ak + tl.arange(0, TILE_K) - ptrs_a = A + pid_z * stride_za \ - + off_h * stride_ha \ - + offs_am[:, None] * stride_ma \ - + offs_ak[None, :] * stride_ka - # initialize pointers to B (sparse) - block_id = tl.load(pinc + 1) - block_id = tl.multiple_of(block_id, 8) - offs_bn = tl.arange(0, TILE_N) - offs_bk = tl.arange(0, TILE_K) - ptrs_b = B + pid_z * stride_zb \ - + block_id * stride_hb \ - + offs_bn[None, :] * stride_bn \ - + offs_bk[:, None] * stride_bk - # ---------------- # - # Inner Loop # - # ---------------- # - acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) - for 
k in range(AS1, 0, -TILE_K): - a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0) - b = tl.load(ptrs_b, mask=True) - acc += tl.dot(a, b) - pinc += 2 - inc_a = tl.load(pinc) - inc_b = tl.load(pinc + 1) - inc_a = tl.multiple_of(inc_a, 8) - inc_b = tl.multiple_of(inc_b, 8) - inc_a = inc_a * stride_ka - ptrs_a += inc_a - ptrs_b += inc_b - # ---------------- # - # Epilogue # - # ---------------- # - c = acc.to(C.dtype.element_ty) - # initialize pointers to C (dense) - offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M) - offs_cn = column * TILE_N + tl.arange(0, TILE_N) - ptrs_c = C + off_h * stride_hc \ - + pid_z * stride_zc \ - + offs_cm[:, None] * stride_mc \ - + offs_cn[None, :] * stride_nc - # write back - tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0) +# AB = (B^T A^T)^T def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): - if a.stride(2) != 1 and a.stride(3) != 1: - a = a.contiguous() - if b.stride(2) != 1 and b.stride(3) != 1: - b = b.contiguous() - # shapes / dtypes - AS0 = a.size(0) - AS1 = a.size(1) - AS2 = a.size(3 if trans_a else 2) - BS2 = block * spdims[1 if trans_b else 2] - dtype = a.dtype - # output - CS0 = AS0 - CS1 = AS1 - CS2 = BS2 if trans_c else AS2 - CS3 = AS2 if trans_c else BS2 - if out is None: - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - else: - assert out.shape == (CS0, CS1, CS2, CS3) - c = out - TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] - grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0] - _dds_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), - AS2, BS2, lut, - TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4, - num_warps=4, GROUP_SIZE_M=4, - ) - return c + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) ############## # MAIN API # diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 6ac76dcc4..bb915be13 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -5,230 +5,235 @@ import triton.language as tl def num_warps(n): - if n < 512: + if n <= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: return 4 - if n < 2048: + if n <= 4096: return 8 return 16 -@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])}) -@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])}) @triton.jit -def _forward( - X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr, - KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr, +def _blocksparse_softmax_fwd( + Out, A, stride_xz, LUT, + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, + ROW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_DENSE: tl.constexpr, ): - pidhm = tl.program_id(0) - pidz = tl.program_id(1) + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) # create index ranges - rxm = pidhm % BLOCK - rbm = pidhm // BLOCK - rxn = tl.arange(0, TN) % BLOCK - rbn = tl.arange(0, 
TN) // BLOCK + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE # extract information from LUT - header = LUT + rbm * 2 + header = LUT + (hm // BLOCK_SIZE) * 2 size = tl.load(header + 0) offset = tl.load(header + 1) - check = rbn < size - rbmn = tl.where(check, rbn, size - 1) - # block id and column id - blockid = tl.load(LUT + offset + rbmn * 4 + 0) - columnid = tl.load(LUT + offset + rbmn * 4 + 1) - rowid = tl.load(LUT + offset + rbmn * 4 + 2) - headid = tl.load(LUT + offset + rbmn * 4 + 3) - # pointers to X - px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - x = tl.load(px, mask=check, other=-float('inf')) - x = x.to(tl.float32) - # apply scale - if APPLY_SCALE: - x = x * scale - # apply RPE - if APPLY_RPE: - prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn - rpe = tl.load(prpe, mask=check, other=0) - x = x + rpe - # apply key-padding mask - if APPLY_KP_MASK: - pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn - kp_m = tl.load(pkp_m, mask=check, other=-float('inf')) - if KP_MASK_MUL: - kp_m = tl.where(kp_m == 0, -float('inf'), 0.) - x = x + kp_m - # apply attention mask - if APPLY_ATTN_MASK: - pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn - attn_m = tl.load(pattn_m, mask=check, other=-float('inf')) - if ATTN_MASK_MUL: - attn_m = tl.where(attn_m == 0, -float('inf'), 0.) - x = x + attn_m + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) # apply causal mask - is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm - x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.) 
+ out = tl.where((ns > m) & is_causal, -float("inf"), out) # computation - x = tl.softmax(x) - tl.store(px, x, mask=check) + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) -@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])}) -@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']}) @triton.jit -def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr): - pidhm = tl.program_id(0) - pidz = tl.program_id(1) +def _blocksparse_softmax_bwd( + DA, stride_zdx, + DOut, stride_zdout, + Out, stride_zout, + scale, + LUT, + DR, extent, stride_zr, stride_hr, stride_er, + is_causal, + ROW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_DENSE: tl.constexpr, +): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) # create index ranges - rxm = pidhm % BLOCK - rbm = pidhm // BLOCK - rxn = tl.arange(0, TN) % BLOCK - rbn = tl.arange(0, TN) // BLOCK - # extract information from look-up table - header = LUT + rbm * 2 + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 size = tl.load(header + 0) offset = tl.load(header + 1) - # bounds checking on lut - check = rbn < size - rbmn = tl.where(check, rbn, size - 1) - # initialize pointers to block-sparse input - blockid = tl.load(LUT + offset + rbmn * 4) - X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - # compute fused softmax backward - x = tl.load(X, mask=check, other=0) - dx = tl.load(DX, mask=check, other=0) - x = x.to(tl.float32) - dx = dx.to(tl.float32) - y = x * (dx - tl.sum(x * dx, 0)) * scale - tl.store(DX, y, mask=check) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + da = a * (dout - tl.sum(a * dout, 0)) + da = tl.where((ns > m) & is_causal, 0., da) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da + # write-back + DAs = DA + z * stride_zdx + off_mn + tl.store(DAs + lane_n, da, mask=mask) class _softmax(torch.autograd.Function): @staticmethod def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() # sizes along rows - sizes = layout.sum(-1).view(-1) + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block # offsets in block format offsets = torch.zeros_like(sizes) offsets[1:] = torch.cumsum(sizes[:-1], dim=0) # block indices 
- layout_sum = sizes.sum() - idx = torch.arange(layout_sum, device=layout.device) - layout_nonzero = layout.nonzero(as_tuple=False) - head = layout_nonzero[:, 0] - rows = layout_nonzero[:, 1] - columns = layout_nonzero[:, 2] - core = torch.stack((idx, columns, rows, head), dim=1).view(-1) - # construct look-up table - offsets = offsets * 4 + 2 * sizes.numel() + columns = layout.nonzero(as_tuple=False)[:, 2] header = torch.stack((sizes, offsets), dim=1).view(-1) - lut = torch.cat((header, core)).type(torch.int32).to(device) - return lut, int(sizes.max()) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) @staticmethod def forward( - ctx, x, scale, rpe, - key_padding_mask, attn_mask, - kp_mask_mode, attn_mask_mode, - is_causal, - spdims, block, lut, maxlut + ctx, a, scale, rel_logits, is_causal, + spdims, block, lut, maxlut, is_dense ): - apply_scale = False if scale == 1.0 else True - # handle None rpe - if rpe is None: - apply_rpe = False - stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0 - rpe = torch.empty(0, dtype=x.dtype, device=x.device) - else: - apply_rpe = True - stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2) - # handle None key_padding_mask - if key_padding_mask is None: - apply_kp_mask = False - stride_zkpm = 0 - key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device) - else: - apply_kp_mask = True - stride_zkpm = key_padding_mask.stride(0) - # handle None attention_mask - if attn_mask is None: - apply_attn_mask = False - stride_zattnm = 0 - attn_mask = torch.empty(0, dtype=x.dtype, device=x.device) - else: - apply_attn_mask = True - stride_zattnm = attn_mask.stride(0) - # run kernel - M = x.shape[0] - grid = [spdims[0] * spdims[1] * block, M] - _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0), - stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, - BLOCK=block, - APPLY_SCALE=apply_scale, - APPLY_RPE=apply_rpe, - APPLY_KP_MASK=apply_kp_mask, - APPLY_ATTN_MASK=apply_attn_mask, - KP_MASK_MUL=(kp_mask_mode == 'mul'), - ATTN_MASK_MUL=(attn_mask_mode == 'mul')) + if scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn + scale, + is_causal, + BLOCK_SIZE=block, + ROW_SIZE=triton.next_power_of_2(maxlut), + IS_DENSE=is_dense, + num_warps=num_warps(maxlut) + ) # save to context - ctx.mark_dirty(x) - ctx.save_for_backward(x, lut) + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) ctx.spdims = spdims ctx.block = block ctx.maxlut = maxlut ctx.scale = scale - ctx.apply_scale = apply_scale - ctx.apply_rpe = apply_rpe - ctx.apply_kp_mask = apply_kp_mask - ctx.apply_attn_mask = apply_attn_mask - ctx.kp_mask_mode = kp_mask_mode - ctx.attn_mask_mode = attn_mask_mode - return x + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out @staticmethod - def backward(ctx, dx): + def backward(ctx, dout): # retrieve from context - x, lut = ctx.saved_tensors + out, lut = ctx.saved_tensors + # relative 
logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) # run kernel - M = x.shape[0] - grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] - _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block) - return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), + dout, dout.stride(0), + out, out.stride(0), + ctx.scale, + lut, + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], + ctx.is_causal, + BLOCK_SIZE=ctx.block, + ROW_SIZE=triton.next_power_of_2(ctx.maxlut), + IS_DENSE=ctx.is_dense, + num_warps=num_warps(ctx.maxlut) + ) + return (da, None, None, dr, None, + None, None, None, None, None, + None, + None, None, None, + None, + None, None, None + ) class softmax: - - def make_lut(self, device): - key = (device, ) - if key not in self.lut_cache: - self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device) - return self.lut_cache[key] - - def __init__(self, layout, block): + def __init__(self, layout, block, device, is_dense=False): self.spdims = layout.shape self.layout = layout self.block = block - self.lut_cache = dict() + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense - def __call__( - self, x, scale=1., rpe=None, - key_padding_mask=None, attn_mask=None, - key_padding_mask_mode='add', attn_mask_mode='add', - is_causal=False - ): - if rpe is not None and rpe.dtype != x.dtype: - raise ValueError('relative position embedding must be %s' % x.dtype) - if attn_mask is not None and attn_mask.dtype != x.dtype: - raise ValueError('Attention mask must be %s' % x.dtype) - if key_padding_mask is not None and key_padding_mask.dtype != x.dtype: - raise ValueError('Key padding mask must be %s' % x.dtype) - lut, maxlut = self.make_lut(x.device) - x = _softmax.apply( - x, scale, rpe, - key_padding_mask, attn_mask, - key_padding_mask_mode, attn_mask_mode, - is_causal, - self.spdims, self.block, - lut, maxlut + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError("relative position embedding must be %s" % a.dtype) + a = _softmax.apply( + a, scale, rel_logits, is_causal, + self.spdims, self.block, self.lut, self.maxlut, self.is_dense, ) - return x + return a diff --git a/python/triton/testing.py b/python/triton/testing.py index 199226ea1..c720f64cf 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -32,6 +32,19 @@ def sparsify_tensor(x, mask, block): return ret +def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None): + if data is None: + data = torch.randn(shape, dtype=torch.float32, device=device) + ref_ret = data + ref_ret = ref_ret * alpha + beta + ref_ret = ref_ret.half().float() + if trans: + ref_ret = ref_ret.t().requires_grad_() + ref_ret = ref_ret.detach().requires_grad_() + tri_ret = ref_ret.clone().detach().requires_grad_() + return ref_ret, tri_ret + + def cutlass_matmul(a, b): if _cutlass is None: raise RuntimeError("Cannot find cutlass library") From 7b48340ffddd7d2624b0330b219eb05b673c086b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 6 Feb 2022 19:11:33 -0800 Subject: [PATCH 063/215] [CI] Some fixes for the build (#451) --- 
python/src/triton.cc | 2 +- python/tutorials/03-matrix-multiplication.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index e9c5e637c..e60f3be37 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -150,7 +150,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f for(int i = 0; i < len; i++){ cache_key += "_"; py::int_ py_i = py::int_(i); - bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end(); + bool specialize = !do_not_specialize.contains(py_i); py::object arg = args[i]; auto arg_ptr = arg.ptr(); diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index ddfe9c0bc..f773a3787 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -169,8 +169,6 @@ import triton.language as tl ], key=['M', 'N', 'K'], ) -# % -# We can now define our kernel as normal, using all the techniques presented above @triton.jit def matmul_kernel( # Pointers to matrices From 822ddcd14b4aaf16d83cfea9e8a9f280ebbd7514 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 8 Feb 2022 11:28:18 -0800 Subject: [PATCH 064/215] [DOCS] Added versioning (#453) --- docs/_templates/versions.html | 27 ++++++++++++++++++++++++ docs/conf.py | 39 +++++++++++++++++++++++------------ 2 files changed, 53 insertions(+), 13 deletions(-) create mode 100644 docs/_templates/versions.html diff --git a/docs/_templates/versions.html b/docs/_templates/versions.html new file mode 100644 index 000000000..c49f844c4 --- /dev/null +++ b/docs/_templates/versions.html @@ -0,0 +1,27 @@ +{%- if current_version %} +
+  <span class="rst-current-version" data-toggle="rst-current-version">
+    <span class="fa fa-book"> Other Versions</span>
+    v: {{ current_version.name }}
+    <span class="fa fa-caret-down"></span>
+  </span>
+  <div class="rst-other-versions">
+    {%- if versions.tags %}
+    <dl>
+      <dt>Tags</dt>
+      {%- for item in versions.tags %}
+      <dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
+      {%- endfor %}
+    </dl>
+    {%- endif %}
+    {%- if versions.branches %}
+    <dl>
+      <dt>Branches</dt>
+      {%- for item in versions.branches %}
+      <dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
+      {%- endfor %}
+    </dl>
+    {%- endif %}
+  </div>
+{%- endif %} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 67a14f47a..f63627b62 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,25 +66,38 @@ def setup(app): import sys import os sys.path.insert(0, os.path.abspath('../python/')) -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon'] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion'] autosummary_generate = True +# versioning config +smv_tag_whitelist = r'^(v1.1.1)$' +smv_branch_whitelist = r'^master$' +smv_remote_whitelist = None +smv_released_pattern = r'^tags/.*$' +smv_outputdir_format = '{ref.name}' +smv_prefer_remote_refs = False + # Sphinx gallery -extensions += ['sphinx_gallery.gen_gallery'] -from sphinx_gallery.sorting import FileNameSortKey -sphinx_gallery_conf = { - 'examples_dirs': '../python/tutorials/', - 'gallery_dirs': 'getting-started/tutorials', - 'filename_pattern': '', - 'ignore_pattern': r'__init__\.py', - 'within_subsection_order': FileNameSortKey, - 'reference_url': { - 'sphinx_gallery': None, - } -} +# extensions += ['sphinx_gallery.gen_gallery'] +# from sphinx_gallery.sorting import FileNameSortKey +# sphinx_gallery_conf = { +# 'examples_dirs': '../python/tutorials/', +# 'gallery_dirs': 'getting-started/tutorials', +# 'filename_pattern': '', +# 'ignore_pattern': r'__init__\.py', +# 'within_subsection_order': FileNameSortKey, +# 'reference_url': { +# 'sphinx_gallery': None, +# } +# } # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] +html_sidebars = { + '**': [ + '_templates/versions.html', + ], +} # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: From 077d6c8ff0c2ed72a662126b339c1e5581de59a0 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 8 Feb 2022 11:42:39 -0800 Subject: [PATCH 065/215] [DOCS] re-activated tutorials --- docs/conf.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index f63627b62..09d355dcd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -78,18 +78,18 @@ smv_outputdir_format = '{ref.name}' smv_prefer_remote_refs = False # Sphinx gallery -# extensions += ['sphinx_gallery.gen_gallery'] -# from sphinx_gallery.sorting import FileNameSortKey -# sphinx_gallery_conf = { -# 'examples_dirs': '../python/tutorials/', -# 'gallery_dirs': 'getting-started/tutorials', -# 'filename_pattern': '', -# 'ignore_pattern': r'__init__\.py', -# 'within_subsection_order': FileNameSortKey, -# 'reference_url': { -# 'sphinx_gallery': None, -# } -# } +extensions += ['sphinx_gallery.gen_gallery'] +from sphinx_gallery.sorting import FileNameSortKey +sphinx_gallery_conf = { + 'examples_dirs': '../python/tutorials/', + 'gallery_dirs': 'getting-started/tutorials', + 'filename_pattern': '', + 'ignore_pattern': r'__init__\.py', + 'within_subsection_order': FileNameSortKey, + 'reference_url': { + 'sphinx_gallery': None, + } +} # Add any paths that contain templates here, relative to this directory. 
templates_path = ['_templates'] From 2fdf0a4fe8c59ab3660acd310a52ec87f40bed9c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 8 Feb 2022 11:45:21 -0800 Subject: [PATCH 066/215] [DOCS] changed build command --- .github/workflows/documentation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index d4ba42733..6d06c96d1 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -30,7 +30,7 @@ jobs: - name: Build docs run: | cd docs - make html + sphinx-multiversion . _build/html/ - name: Publish docs run: | From 4941bc7001f84485837212e8e6d8eb5d128ca4f1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 8 Feb 2022 16:53:56 -0800 Subject: [PATCH 067/215] [DOCS] Some more fixes (#455) --- docs/conf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 09d355dcd..4d62c5650 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,14 +34,17 @@ def process_sig(app, what, name, obj, options, signature, return_annotation): def setup(app): """Customize function args retrieving to get args under decorator.""" import sphinx - import triton + import os app.connect("autodoc-process-signature", process_sig) + os.system("pip install -e ../python") + def forward_jit_fn(func): old = func def wrapped(obj, **kwargs): + import triton if isinstance(obj, triton.code_gen.JITFunction): obj = obj.fn return old(obj) @@ -52,6 +55,7 @@ def setup(app): old_documenter = sphinx.ext.autosummary.get_documenter def documenter(app, obj, parent): + import triton if isinstance(obj, triton.code_gen.JITFunction): obj = obj.fn return old_documenter(app, obj, parent) @@ -70,7 +74,7 @@ extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosu autosummary_generate = True # versioning config -smv_tag_whitelist = r'^(v1.1.1)$' +smv_tag_whitelist = r'^(v1.1.2)$' smv_branch_whitelist = r'^master$' smv_remote_whitelist = None smv_released_pattern = r'^tags/.*$' From 40093a9878be6d5c646447ec1f608ad322a0fbb9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 9 Feb 2022 01:32:41 -0800 Subject: [PATCH 068/215] [DOCS] Multiple versions are now supported (#457) --- .github/workflows/documentation.yml | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 6d06c96d1..e921709ba 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -12,7 +12,7 @@ jobs: steps: - + - name: Checkout gh-pages uses: actions/checkout@v1 with: @@ -21,21 +21,30 @@ jobs: - name: Checkout branch uses: actions/checkout@v1 - - name: Install Triton - run: | - alias python='python3' - cd python - pip3 install -e '.[tutorials]' - - name: Build docs run: | + git fetch origin master:master cd docs sphinx-multiversion . _build/html/ - name: Publish docs run: | + git branch + # update docs + rm -r /tmp/triton-docs; + mkdir /tmp/triton-docs; + mv docs/_build/html/* /tmp/triton-docs/ git checkout gh-pages - sh ./update-website.sh + cp -r CNAME /tmp/triton-docs/ + cp -r index.html /tmp/triton-docs/ + cp -r .nojekyll /tmp/triton-docs/ + rm -r * + cp -r /tmp/triton-docs/* . + # ln -s master/index.html . + # mv master docs + git add . 
+ git commit -am "[GH-PAGES] Updated website" + # publish docs eval `ssh-agent -s` DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }} git remote set-url origin git@github.com:openai/triton.git From 9b100302d3818e8ac396c743cd691f8147a5edb5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 10 Feb 2022 01:57:39 -0800 Subject: [PATCH 069/215] [FRONTEND] Now using pybind11 to release GIL (#458) --- python/src/triton.cc | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index e60f3be37..3410df6b8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -328,13 +328,10 @@ void init_triton_runtime(py::module &&m) { if(grid_0*grid_1*grid_2 > 0) { // release the gil in case the enqueue blocks // cuda will block if too many ops are enqueued - Py_BEGIN_ALLOW_THREADS - + py::gil_scoped_release allow_threads; drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); - - Py_END_ALLOW_THREADS } return bin; }); @@ -394,14 +391,13 @@ void init_triton_runtime(py::module &&m) { size_t args_size = args.size(); // release the gil in case the enqueue blocks // cuda will block if too many ops are enqueued - Py_BEGIN_ALLOW_THREADS + py::gil_scoped_release allow_threads; if(backend == HOST) host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); if(backend == CUDA) cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); if(backend == ROCM) hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); - Py_END_ALLOW_THREADS }); @@ -468,7 +464,7 @@ std::tuple cu_compile_ttir(const std::string& name, asm_map_t &asm_map){ int n_shared_bytes; - Py_BEGIN_ALLOW_THREADS + py::gil_scoped_release allow_threads; llvm::LLVMContext ctx; // device properties CUdevice dev = (CUdevice)device; @@ -494,7 +490,6 @@ std::tuple cu_compile_ttir(const std::string& name, py::bytes bytes(cubin); asm_map["cubin"] = bytes; } - Py_END_ALLOW_THREADS return std::make_tuple(name, asm_map, n_shared_bytes); } From a9dfdcaaa987c84e4f66cacb38fdcb7d1d9f1837 Mon Sep 17 00:00:00 2001 From: daadaada Date: Sat, 12 Feb 2022 14:34:42 +0800 Subject: [PATCH 070/215] [FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456) --- python/test/regression/test_performance.py | 126 ++++++++++++++------- python/triton/code_gen.py | 14 +-- python/triton/ops/matmul.py | 26 ++++- python/triton/ops/matmul_perf_model.py | 54 +++++++-- python/triton/testing.py | 50 +++++++- 5 files changed, 201 insertions(+), 69 deletions(-) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 39299a89a..1df3a0b49 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -6,6 +6,9 @@ import torch import triton import triton.language as tl +from triton.testing import get_dram_gbps, get_max_tensorcore_tflops + +DEVICE_NAME = 'v100' ####################### # Utilities @@ -25,42 +28,76 @@ def nvsmi(attrs): # Matrix Multiplication ####################### +sm_clocks = {'v100': 1350, 'a100': 1350} +mem_clocks = {'v100': 877, 'a100': 1215} + matmul_data = { - # square - (256, 256, 256): {'v100': 0.027}, - (512, 512, 512): {'v100': 0.158}, - (1024, 1024, 1024): {'v100': 0.466}, - (2048, 2048, 
2048): {'v100': 0.680}, - (4096, 4096, 4096): {'v100': 0.831}, - (8192, 8192, 8192): {'v100': 0.849}, - # tall-skinny - (16, 1024, 1024): {'v100': 0.0128}, - (16, 4096, 4096): {'v100': 0.0883}, - (16, 8192, 8192): {'v100': 0.101}, - (64, 1024, 1024): {'v100': 0.073}, - (64, 4096, 4096): {'v100': 0.270}, - (64, 8192, 8192): {'v100': 0.360}, - (1024, 64, 1024): {'v100': 0.0692}, - (4096, 64, 4096): {'v100': 0.264}, - (8192, 64, 8192): {'v100': 0.323}, + 'v100': { + # square + (256, 256, 256): {'float16': 0.027}, + (512, 512, 512): {'float16': 0.158}, + (1024, 1024, 1024): {'float16': 0.466}, + (2048, 2048, 2048): {'float16': 0.680}, + (4096, 4096, 4096): {'float16': 0.831}, + (8192, 8192, 8192): {'float16': 0.849}, + # tall-skinny + (16, 1024, 1024): {'float16': 0.0128}, + (16, 4096, 4096): {'float16': 0.0883}, + (16, 8192, 8192): {'float16': 0.101}, + (64, 1024, 1024): {'float16': 0.073}, + (64, 4096, 4096): {'float16': 0.270}, + (64, 8192, 8192): {'float16': 0.459}, + (1024, 64, 1024): {'float16': 0.0692}, + (4096, 64, 4096): {'float16': 0.264}, + (8192, 64, 8192): {'float16': 0.452}, + }, + 'a100': { + (256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006}, + (512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030}, + (1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169}, + (2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385}, + (4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711}, + (8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860}, + # tall-skinny + (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005}, + (16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259}, + (16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431}, + (64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169}, + (64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097}, + (64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174}, + (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017}, + (4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102}, + (8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177}, + } # # deep reductions - # (64 , 64 , 16384) : {'v100': 0.}, - # (64 , 64 , 65536) : {'v100': 0.}, - # (256 , 256 , 8192 ) : {'v100': 0.}, - # (256 , 256 , 32768) : {'v100': 0.}, + # (64 , 64 , 16384) : {'a100': 0.}, + # (64 , 64 , 65536) : {'a100': 0.}, + # (256 , 256 , 8192 ) : {'a100': 0.}, + # (256 , 256 , 32768) : {'a100': 0.}, } -@pytest.mark.parametrize('M, N, K', matmul_data.keys()) -def test_matmul(M, N, K): +@pytest.mark.parametrize('M, N, K, dtype_str', + [(M, N, K, dtype_str) + for M, N, K in matmul_data[DEVICE_NAME].keys() + for dtype_str in ['float16']]) +def test_matmul(M, N, K, dtype_str): + if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100': + pytest.skip('Only test float32 & int8 on a100') + dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] torch.manual_seed(0) - ref_gpu_util = matmul_data[(M, N, K)]['v100'] + ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str] cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - ref_sm_clock = 1350 - max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock + ref_sm_clock = sm_clocks[DEVICE_NAME] + max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz' - a = 
torch.randn((M, K), dtype=torch.float16, device='cuda') - b = torch.randn((K, N), dtype=torch.float16, device='cuda') + if dtype == torch.int8: + a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda') + b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda') + b = b.t() # only test row-col layout + else: + a = torch.randn((M, K), dtype=dtype, device='cuda') + b = torch.randn((K, N), dtype=dtype, device='cuda') fn = lambda: triton.ops.matmul(a, b) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) cur_gpu_perf = 2. * M * N * K / ms * 1e-9 @@ -87,23 +124,34 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements, elementwise_data = { - 1024 * 16: {'v100': 0.0219}, - 1024 * 64: {'v100': 0.0791}, - 1024 * 256: {'v100': 0.243}, - 1024 * 1024: {'v100': 0.534}, - 1024 * 4096: {'v100': 0.796}, - 1024 * 16384: {'v100': 0.905}, - 1024 * 65536: {'v100': 0.939}, + 'v100': { + 1024 * 16: 0.0219, + 1024 * 64: 0.0791, + 1024 * 256: 0.243, + 1024 * 1024: 0.534, + 1024 * 4096: 0.796, + 1024 * 16384: 0.905, + 1024 * 65536: 0.939, + }, + 'a100': { + 1024 * 16: 0.008, + 1024 * 64: 0.034, + 1024 * 256: 0.114, + 1024 * 1024: 0.315, + 1024 * 4096: 0.580, + 1024 * 16384: 0.782, + 1024 * 65536: 0.850, + } } -@pytest.mark.parametrize('N', elementwise_data.keys()) +@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys()) def test_elementwise(N): torch.manual_seed(0) - ref_gpu_util = elementwise_data[N]['v100'] + ref_gpu_util = elementwise_data[DEVICE_NAME][N] cur_mem_clock = nvsmi(['clocks.current.memory'])[0] - ref_mem_clock = 877 - max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3 + ref_mem_clock = mem_clocks[DEVICE_NAME] + max_gpu_perf = get_dram_gbps() assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index dc2b375b8..894b3f1e3 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -811,12 +811,12 @@ class Autotuner: # prune configs if prune_configs_by: perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'prune_num_stages_by' in prune_configs_by: - prune_num_stages_by = prune_configs_by['prune_num_stages_by'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] else: - perf_model, top_k, prune_num_stages_by = None, None, None + perf_model, top_k, early_config_prune = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k - self.prune_num_stages_by = prune_num_stages_by + self.early_config_prune = early_config_prune def _bench(self, *args, config, **meta): # check for conflicts, i.e. 
meta-parameters both provided @@ -844,8 +844,8 @@ class Autotuner: if key not in self.cache: # prune configs pruned_configs = self.configs - if self.prune_num_stages_by: - pruned_configs = self.prune_num_stages_by(self.configs, self.nargs) + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) if self.perf_model: top_k = self.configs_top_k if isinstance(top_k, float) and top_k <= 1.0: @@ -1096,7 +1096,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. :type reset_to_zero: list[str] """ diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 9466b9ba7..f1ac78849 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -2,7 +2,7 @@ import torch import triton import triton.language as tl -from .matmul_perf_model import estimate_matmul_time, prune_num_stages +from .matmul_perf_model import early_config_prune, estimate_matmul_time def init_to_zero(name): @@ -27,7 +27,7 @@ def get_configs_io_bound(): @triton.heuristics({ - 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, }) @triton.autotune( configs=[ @@ -41,10 +41,20 @@ def get_configs_io_bound(): triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), ] + get_configs_io_bound(), key=['M', 'N', 'K'], prune_configs_by={ - 'prune_num_stages_by': prune_num_stages, + 'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, 'top_k': 10 }, @@ -55,7 +65,9 @@ def _kernel(A, B, C, M, N, K, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: 
tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -76,7 +88,7 @@ def _kernel(A, B, C, M, N, K, # pointers A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for k in range(K, 0, -BLOCK_K * SPLIT_K): if EVEN_K: a = tl.load(A) @@ -119,13 +131,15 @@ class _matmul(torch.autograd.Function): _, N = b.shape # allocates output c = torch.empty((M, N), device=device, dtype=a.dtype) + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch kernel grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - GROUP_M=8) + GROUP_M=8, ACC_TYPE=ACC_TYPE) return c @staticmethod diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index 98a85bc85..9c10b88d8 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -4,20 +4,36 @@ import torch import triton import triton._C.libtriton.triton as _triton -from triton.testing import get_dram_gbps, get_max_tensorcore_tflops +from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops -def get_tensorcore_tflops(backend, device, num_ctas, num_warps): +def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): ''' return compute throughput in TOPS ''' total_warps = num_ctas * min(num_warps, 4) num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs - tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device) + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device) return tflops +def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device) + return tflops + + +def get_tflops(backend, device, num_ctas, num_warps, dtype): + cc = _triton.runtime.cc(backend, device) + if cc < 80 and dtype == torch.float32: + return get_simd_tflops() + return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype) + + def estimate_matmul_time( # backend, device, num_warps, num_stages, + A, B, C, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, debug=False, **kwargs @@ -26,6 +42,8 @@ def estimate_matmul_time( = max(compute, loading) + store ''' backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() num_cta_m = triton.cdiv(M, BLOCK_M) num_cta_n = triton.cdiv(N, BLOCK_N) @@ -37,7 +55,7 @@ def estimate_matmul_time( # time to compute total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS - tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) + tput = get_tflops(backend, device, num_ctas, num_warps, dtype) compute_ms = total_ops / tput # time to load data @@ -48,10 +66,10 @@ def 
estimate_matmul_time( dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) # assume 80% of (following) loads are in L2 cache - load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2) - load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1) - load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1)) - load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1) + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) # total total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) @@ -60,7 +78,7 @@ def estimate_matmul_time( # estimate storing time store_bw = dram_bw * 0.6 # :o - store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB if SPLIT_K == 1: store_ms = store_c_dram / store_bw else: @@ -78,14 +96,28 @@ def estimate_matmul_time( return total_time_ms -def prune_num_stages(configs, named_args): +def early_config_prune(configs, named_args): backend = _triton.runtime.backend.CUDA device = torch.cuda.current_device() cc = _triton.runtime.cc(backend, device) # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + max_shared_memory = _triton.runtime.max_shared_memory(backend, device) + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs # Some dtypes do not allow atomic_add - if named_args['A'].dtype == torch.bfloat16: + if dtype not in [torch.float16, torch.float32]: configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) diff --git a/python/triton/testing.py b/python/triton/testing.py index c720f64cf..fbca719ff 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -330,18 +330,56 @@ def get_dram_gbps(backend=None, device=None): device = torch.cuda.current_device() mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) bus_width = _triton.runtime.global_memory_bus_width(backend, device) - bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s return bw_gbps -def get_max_tensorcore_tflops(backend, device): +def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None): + if not backend: + backend = _triton.runtime.backend.CUDA + if not device: + device = torch.cuda.current_device() num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs - clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz - # assume fp32 += fp16*fp16 + if not clock_rate: + clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz cc = _triton.runtime.cc(backend, device) if cc < 80: + assert dtype == torch.float16 ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores else: - ops_per_sub_core = 512 - tflops = num_subcores * clock_rate * 
ops_per_sub_core / (1024 * 1024 * 1024) + if dtype == torch.float32: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 512 + elif dtype == torch.int8: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None): + if not backend: + backend = _triton.runtime.backend.CUDA + if not device: + device = torch.cuda.current_device() + num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs + clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz + cc = _triton.runtime.cc(backend, device) + if cc < 80: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops From 98ed7db8c1f3e7de6e67cbb3838f0692fb541be4 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 24 Feb 2022 14:56:24 -0800 Subject: [PATCH 071/215] [CODEGEN] Improvements and bugfixes (#463) --- CMakeLists.txt | 11 ++++- include/triton/driver/llvm.h | 3 +- include/triton/ir/builder.h | 6 +-- include/triton/ir/dispatch.h | 3 +- include/triton/ir/instructions.h | 32 ++++++++++---- lib/codegen/selection/generator.cc | 16 ++++--- lib/codegen/transform/peephole.cc | 3 +- lib/codegen/transform/pipeline.cc | 10 ++--- lib/driver/llvm.cc | 69 ++++++++++++++++++++---------- lib/ir/builder.cc | 12 +++--- lib/ir/dispatch.cc | 32 ++++++++++---- lib/ir/instructions.cc | 30 +++++++------ python/src/triton.cc | 4 +- python/triton/language/core.py | 4 +- 14 files changed, 154 insertions(+), 81 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f44c35aa7..c3aadf9c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,15 @@ libLLVMBinaryFormat.a libLLVMAMDGPUInfo.a libLLVMSupport.a libLLVMDemangle.a +libLLVMPasses.a +libLLVMAnalysis.a +libLLVMTransformUtils.a +libLLVMScalarOpts.a +libLLVMTransformUtils.a +libLLVMipo.a +libLLVMObjCARCOpts.a +libLLVMCoroutines.a +libLLVMAnalysis.a ) endif() include_directories("${LLVM_INCLUDE_DIRS}") @@ -148,7 +157,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) if(WIN32) target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32 else() - target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY}) + target_link_libraries(triton ${LLVM_LIBRARIES} z) endif() diff --git a/include/triton/driver/llvm.h b/include/triton/driver/llvm.h index 89dc98169..c0c1c0f37 100644 --- a/include/triton/driver/llvm.h +++ b/include/triton/driver/llvm.h @@ -9,8 +9,9 @@ namespace triton{ namespace driver{ void init_llvm(); +std::string path_to_ptxas(int& version); std::string llir_to_ptx(llvm::Module* module, int cc, int version); -std::string ptx_to_cubin(const std::string& ptx, int cc); +std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc); CUmodule ptx_to_cumodule(const std::string& ptx, int cc); std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc); hipModule_t amdgpu_to_hipmodule(const std::string& path); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 67ab47c90..2b6bc6ab3 100644 --- 
a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -136,9 +136,9 @@ public: value *create_xor(value *lhs, value *rhs); value *create_or(value *lhs, value *rhs); // Input/Output - value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, bool is_volatile); + value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); value *create_store(value *ptr, value *val); - value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile); + value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); value *create_masked_store(value *ptr, value *val, value *mask); // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); @@ -163,7 +163,7 @@ public: // These have no place in the IR, and hopefully they can be removed at some point value *create_umulhi(value* lhs, value* rhs); value *create_copy_to_shared(value *arg); - value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); + value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY); value *create_copy_from_shared(value *arg); value *create_barrier(const std::string &name = ""); value *create_async_wait(int N); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index 7dc8ab0ed..ef14043dd 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -69,7 +69,8 @@ struct dispatch{ static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, int is_volatile, ir::builder *builder); + static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, + const std::string& eviction_policy, int is_volatile, ir::builder *builder); static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 5af077e8f..0fb85db02 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -406,13 +406,20 @@ public: NONE=0, CA, CG, - }; + }; + + enum EVICTION_POLICY : uint32_t { + NORMAL=0, + EVICT_FIRST, + EVICT_LAST, + }; CACHE_MODIFIER get_cache_modifier() const { return cache_; } + EVICTION_POLICY get_eviction_policy() const { return eviction_; } bool get_is_volatile() const { return is_volatile_; } protected: - load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, + load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name = "", instruction *next = nullptr); std::string get_cache_modifier_repr() const { @@ -420,6 +427,11 @@ protected: if (cache_ == CG) return ".cg"; return ""; } + std::string get_eviction_policy_repr() const { + if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; + if (eviction_ == EVICT_LAST) return ".L2::evict_last"; + } + EVICTION_POLICY eviction_; CACHE_MODIFIER cache_; std::string 
get_volatile_repr() { @@ -435,11 +447,12 @@ private: class unmasked_load_inst: public load_inst { private: std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } - unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next); + unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next); public: static unmasked_load_inst* create(value *ptr, - CACHE_MODIFIER cache, bool is_volatile, + CACHE_MODIFIER cache, EVICTION_POLICY eviction, + bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(unmasked_load_inst) @@ -450,7 +463,7 @@ public: class masked_load_inst: public load_inst { private: std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } - masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, + masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next); public: @@ -459,7 +472,8 @@ public: value *get_false_value_operand() { return get_operand(2); } // factory method static masked_load_inst* create(value *ptr, value *mask, value *false_value, - CACHE_MODIFIER cache, bool is_volatile, + CACHE_MODIFIER cache, EVICTION_POLICY eviction, + bool is_volatile, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_load_inst) @@ -470,8 +484,9 @@ public: class masked_load_async_inst: public load_inst { private: std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); } - masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, - const std::string &name, instruction *next); + masked_load_async_inst(value *ptr, value *mask, value *false_value, + CACHE_MODIFIER cache, EVICTION_POLICY eviction, + const std::string &name, instruction *next); public: // accessors @@ -480,6 +495,7 @@ public: // factory method static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, + EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_load_async_inst) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d2ebce1c6..b4f1dd41e 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -119,7 +119,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) #define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) -#define load(...) builder_->CreateLoad(__VA_ARGS__) +#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr) #define lshr(...) builder_->CreateLShr(__VA_ARGS__) #define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) #define min_num(...) 
builder_->CreateMinNum(__VA_ARGS__) @@ -576,18 +576,19 @@ void generator::visit_cast_inst(ir::cast_inst* x) { // <> BF16 if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ // FP32 -> BF16 - if(op_sca_ty->is_fp32_ty()) - // for(size_t i = 0; i < x_idxs.size(); i++) - // vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); + if(op_sca_ty->is_fp32_ty()){ for (indices_t idx: idxs_.at(x)) { Value *arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); } + return; + } // BF16 -> FP32 - if(ret_sca_ty->is_fp32_ty()) + if(ret_sca_ty->is_fp32_ty()){ for(size_t i = 0; i < x_idxs.size(); i++) vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); - return; + return; + } } @@ -697,12 +698,13 @@ void generator::visit_load_inst(ir::load_inst* x){ std::ostringstream asm_oss; asm_oss << "@$" << n_words; // predicate asm_oss << " ld"; -// std::cout << x->get_is_volatile() << std::endl; if(x->get_is_volatile()) asm_oss << ".volatile"; asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; + if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; + if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index e30ab9b35..0961efc9c 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -123,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build int nts = layout->nts(layout->get_order()[0]); int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; if(nts*dtsize >= 4){ - ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier()); + ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy()); copy_to_shared->replace_all_uses_with(new_load); return true; } @@ -215,6 +215,7 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b if_value->get_mask_operand(), select->get_else_value_op(), if_value->get_cache_modifier(), + if_value->get_eviction_policy(), if_value->get_is_volatile()); select->replace_all_uses_with(new_load); return true; diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index eb3fe6164..c85ba43a1 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -178,7 +178,7 @@ void pipeline::run(ir::module &mod) { false_value = remat_false_value; } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_is_volatile()); + first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); for (int stage = 1; stage < num_stages-1; ++stage) { // mask is the loop condition of the previous iteration @@ -193,7 +193,7 @@ void pipeline::run(ir::module &mod) { first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); false_value = remat_false_value; } - first_loads[stage] = 
builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_is_volatile()); + first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); } // create new phis for induction variables @@ -222,7 +222,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); // phi node @@ -257,7 +257,7 @@ void pipeline::run(ir::module &mod) { } else false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); + ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); // pre-fetch next iteration builder.set_insert_point(block->get_inst_list().back()); ir::value* next_ptr = ptr->get_value_for_block(block); @@ -268,7 +268,7 @@ void pipeline::run(ir::module &mod) { next_mask = builder.create_and(next_mask, remat_mask); false_value = remat_false_value; } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_is_volatile()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); // phi node builder.set_insert_point(block->get_first_non_phi()); ir::phi_node* new_load = builder.create_phi(ty, 2); diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index be8200b86..e7bb47bef 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -59,6 +59,13 @@ #include "llvm/Analysis/TargetLibraryInfo.h" // end AMD stuff +extern "C"{ + int set_curterm(char* nterm){ return 0; } + int del_curterm(char* nterm){ return 0; } + int tigetnum(char *capname) { return 0; } + int setupterm(char *term, int fildes, int *errret) { return 0; } +} + namespace triton{ namespace driver{ @@ -77,6 +84,7 @@ void init_llvm() { } } + /* ------------------------ */ // CUDA // /* ------------------------ */ @@ -89,7 +97,42 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s return true; } +std::string path_to_ptxas(int& version) { + std::string ret; + // search pathes for ptxas + std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; + std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); + if(!triton_ptxas.empty()) + ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); + // see what path for ptxas are valid + std::vector working_ptxas; + for(std::string prefix: ptxas_prefixes){ + std::string ptxas = prefix + "ptxas"; + bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0; + if(works) + working_ptxas.push_back(ptxas); + } + // error if no working ptxas was found + if(working_ptxas.empty()) + throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" + " but a working version could not be found."); + std::string ptxas = 
working_ptxas.front(); + // parse version + std::regex version_regex("release (\\d+)\\.(\\d+)"); + std::smatch match; + if(std::regex_search(ret, match, version_regex)){ + int major = std::stoi(match[1]); + int minor = std::stoi(match[2]); + version = major*1000 + minor*10; + } + else + throw std::runtime_error("couldn't parse ptxas version: " + ret); + return ptxas; +} + + int vptx(int version){ + if(version >= 11040) return 74; if(version >= 11030) return 73; if(version >= 11020) return 72; if(version >= 11010) return 71; @@ -103,7 +146,7 @@ int vptx(int version){ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ // LLVM version in use may not officially support target hardware int max_nvvm_cc = 75; - int max_nvvm_ptx = 64; + int max_nvvm_ptx = 74; // options auto options = llvm::cl::getRegisteredOptions(); auto* short_ptr = static_cast*>(options["nvptx-short-ptr"]); @@ -120,7 +163,8 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ std::string triple = "nvptx64-nvidia-cuda"; std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc)); std::string layout = ""; - std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); + std::string features = ""; + // std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; @@ -164,26 +208,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ return result; } -std::string ptx_to_cubin(const std::string& ptx, int cc) { - std::string version; - // search pathes for ptxas - std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; - std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); - if(!triton_ptxas.empty()) - ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); - // see what path for ptxas are valid - std::vector working_ptxas; - for(std::string prefix: ptxas_prefixes){ - std::string ptxas = prefix + "ptxas"; - bool works = tools::exec(ptxas + " --version 2>&1", version) == 0; - if(works) - working_ptxas.push_back(ptxas); - } - // error if no working ptxas was found - if(working_ptxas.empty()) - throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" - " but a working version could not be found."); - std::string ptxas = working_ptxas.front(); +std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) { // compile ptx with ptxas char _fsrc[L_tmpnam]; char _flog[L_tmpnam]; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index bf59b5b39..fff73e665 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -291,16 +291,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE) // load/store instructions //===----------------------------------------------------------------------===// -value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile){ - return insert(unmasked_load_inst::create(ptr, cache, is_volatile)); +value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ + return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile)); } value *builder::create_store(value *ptr, value *val){ return insert(unmasked_store_inst::create(ptr, val)); } -value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile){ - return insert(masked_load_inst::create(ptr, mask, false_value, cache, is_volatile)); +value 
*builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ + return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile)); } value *builder::create_masked_store(value *ptr, value *val, value *mask){ @@ -404,8 +404,8 @@ value *builder::create_copy_from_shared(value *arg) { return insert(copy_from_shared_inst::create(arg)); } -value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) { - return insert(masked_load_async_inst::create(ptr, mask, false_value, cache)); +value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) { + return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction)); } value *builder::create_barrier(const std::string &name) { diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 5a2d33a98..664fbb983 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -506,6 +506,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build return input; ir::type *src_sca_ty = src_ty->get_scalar_ty(); ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); + // + if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) || + (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){ + return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder); + } // FP Truncation bool truncate_fp = src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_floating_point_ty() && @@ -569,18 +574,17 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build // Memory Operators //===----------------------------------------------------------------------===// -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, int is_volatile, ir::builder* builder) { +ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) { if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); if(ptr->get_type()->is_block_ty()){ - if(mask){ + if(mask) mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - } - if(other){ + if(other) other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); - other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - } } + if(other) + other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); // treat bool* as int8* @@ -599,8 +603,20 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con else throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); } + // eviction policy + load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default + if(!eviction_policy.empty()){ + if (eviction_policy == "evict_last") + eviction = load_inst::EVICT_LAST; + else if(eviction_policy == "evict_first") + eviction = load_inst::EVICT_FIRST; + else + throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported"); + } + + if (!mask && 
!other) - return builder->create_load(ptr, cache, is_volatile); + return builder->create_load(ptr, cache, eviction, is_volatile); if (!mask) throw std::runtime_error("`other` cannot be provided without `mask`"); auto shape = ptr->get_type()->get_block_shapes(); @@ -609,7 +625,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, con if(ptr->get_type()->is_block_ty()) other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); } - return builder->create_masked_load(ptr, mask, other, cache, is_volatile); + return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile); } ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index ac5a17289..c225b315f 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -434,8 +434,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n { } // load_inst -load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) - : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), is_volatile_(is_volatile) +load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) + : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile) { } // load @@ -448,44 +448,46 @@ type *load_inst::get_pointee_type(type *ty) { } // unmasked_load -unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, is_volatile, name, next) { +unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) + : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) { set_operand(0, ptr); } -unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, bool is_volatile, const std::string &name, instruction *next) { - return new unmasked_load_inst(ptr, cache, is_volatile, name, next); +unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) { + return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next); } // masked load -masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, bool is_volatile, +masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, + bool is_volatile, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD, 3, cache, is_volatile, name, next) { + : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, bool is_volatile, + load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, + bool is_volatile, 
const std::string &name, instruction *next) { - return new masked_load_inst(ptr, mask, false_value, cache, is_volatile, name, next); + return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next); } // masked load async masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, + load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, const std::string &name, instruction *next) - : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, false, name, next) { + : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) { set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value, - load_inst::CACHE_MODIFIER cache, + load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, const std::string &name, instruction *next) { - return new masked_load_async_inst(ptr, mask, false_value, cache, name, next); + return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next); } // store diff --git a/python/src/triton.cc b/python/src/triton.cc index 3410df6b8..c5c5b196f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -472,7 +472,7 @@ std::tuple cu_compile_ttir(const std::string& name, size_t minor = cuGetInfo(dev); size_t cc = major*10 + minor; int version; - drv::dispatch::cuDriverGetVersion(&version); + std::string ptxas_path = drv::path_to_ptxas(version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); @@ -485,7 +485,7 @@ std::tuple cu_compile_ttir(const std::string& name, std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version); asm_map["ptx"] = py::cast(ptx); // PTX -> Binary - std::string cubin = drv::ptx_to_cubin(ptx, cc); + std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc); if(!cubin.empty()){ py::bytes bytes(cubin); asm_map["cubin"] = bytes; diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f4188f1c7..df25e59fb 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -556,7 +556,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin -def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _builder=None): +def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. 
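For reference, the `eviction_policy` argument that this patch threads through the IR and exposes on `tl.load` accepts the strings `"evict_first"` and `"evict_last"` (an empty string keeps the default policy, as handled in `dispatch.cc` above). The kernel below is a minimal sketch of how a user would pass it; the kernel name, sizes and launch code are illustrative and not part of the patch:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _streaming_copy(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # each element is read exactly once, so hint L1 to evict these lines first
    x = tl.load(src_ptr + offsets, mask=mask, eviction_policy="evict_first")
    tl.store(dst_ptr + offsets, x, mask=mask)

src = torch.randn(8192, device='cuda')
dst = torch.empty_like(src)
grid = lambda args: (triton.cdiv(src.numel(), args['BLOCK_SIZE']),)
_streaming_copy[grid](src, dst, src.numel(), BLOCK_SIZE=1024)
```

With this hint, the emitted `ld.global` instructions carry the `.L1::evict_first` qualifier, as generated in `generator.cc` earlier in this patch.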
@@ -573,7 +573,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", volatile=False, _bui :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - return frontend.load(pointer, mask, other, cache_modifier, volatile, _builder) + return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) @builtin From d9dd97492f228020573b39a9cec14ee3b8776957 Mon Sep 17 00:00:00 2001 From: daadaada Date: Fri, 25 Feb 2022 08:07:10 +0800 Subject: [PATCH 072/215] Use unique_ptr in ir::context_impl (#462) Co-authored-by: Philippe Tillet --- include/triton/ir/context.h | 2 -- include/triton/ir/context_impl.h | 18 ++++++++---------- lib/ir/constant.cc | 20 ++++++++++---------- lib/ir/type.cc | 12 ++++++------ 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/include/triton/ir/context.h b/include/triton/ir/context.h index 55edf31cd..d824c98b6 100644 --- a/include/triton/ir/context.h +++ b/include/triton/ir/context.h @@ -9,7 +9,6 @@ namespace triton{ namespace ir{ -class builder; class type; class context_impl; @@ -21,7 +20,6 @@ public: context& operator=(const context&) = delete; public: - ir::builder* builder = nullptr; std::shared_ptr p_impl; }; diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index e43b5ad57..081ea249d 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -3,17 +3,15 @@ #ifndef _TRITON_IR_CONTEXT_IMPL_H_ #define _TRITON_IR_CONTEXT_IMPL_H_ -#include #include "triton/ir/type.h" +#include "triton/ir/constant.h" +#include +#include namespace triton{ namespace ir{ class context; -class constant; -class constant_int; -class constant_fp; -class undef_value; /* Context impl */ class context_impl { @@ -30,16 +28,16 @@ public: integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types - std::map, pointer_type*> ptr_tys; + std::map, std::unique_ptr> ptr_tys; // Block types - std::map, block_type*> block_tys; + std::map, std::unique_ptr> block_tys; // Int constants - std::map, constant_int*> int_constants_; + std::map, std::unique_ptr> int_constants_; // Float constants - std::map, constant_fp*> fp_constants_; + std::map, std::unique_ptr> fp_constants_; // undef values - std::map uv_constants_; + std::map> uv_constants_; }; diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index b2a50c3be..ab1f6f497 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -47,10 +47,10 @@ constant_int *constant_int::get(type *ty, uint64_t value) { if (!ty->is_integer_ty()) throw std::runtime_error("Cannot create constant_int with non integer ty"); context_impl *impl = ty->get_context().p_impl.get(); - constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)]; - if(cst == nullptr) - cst = new constant_int(ty, value); - return cst; + std::unique_ptr &cst = impl->int_constants_[std::make_pair(ty, value)]; + if(!cst) + cst.reset(new constant_int(ty, value)); + return cst.get(); } @@ -73,10 +73,10 @@ constant *constant_fp::get_zero_value_for_negation(type *ty) { constant *constant_fp::get(type *ty, double v){ context_impl *impl = ty->get_context().p_impl.get(); - constant_fp *&result = impl->fp_constants_[std::make_pair(ty, v)]; + std::unique_ptr &result = impl->fp_constants_[std::make_pair(ty, v)]; if(!result) - result = new constant_fp(ty, v); - return result; + result.reset(new constant_fp(ty, v)); + return result.get(); } @@ -86,10 +86,10 @@ 
undef_value::undef_value(type *ty) undef_value *undef_value::get(type *ty) { context_impl *impl = ty->get_context().p_impl.get(); - undef_value *&result = impl->uv_constants_[ty]; + std::unique_ptr &result = impl->uv_constants_[ty]; if(!result) - result = new undef_value(ty); - return result; + result.reset(new undef_value(ty)); + return result.get(); } /* global value */ diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 74066a65a..7e4e4e5d7 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -167,10 +167,10 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){ assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); // look-up context_impl *impl = elt_ty->get_context().p_impl.get(); - pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; + std::unique_ptr &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; if(!entry) - entry = new pointer_type(elt_ty, address_space); - return entry; + entry.reset(new pointer_type(elt_ty, address_space)); + return entry.get(); } //===----------------------------------------------------------------------===// @@ -217,10 +217,10 @@ block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) { assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); // look-up context_impl *impl = elt_ty->get_context().p_impl.get(); - block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; + std::unique_ptr &entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; if(!entry) - entry = new block_type(elt_ty, shapes); - return entry; + entry.reset(new block_type(elt_ty, shapes)); + return entry.get(); } block_type* block_type::get_same_shapes(type *ty, type *ref){ From bb5765df5c293889296468199868134e71b125a8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 3 Mar 2022 22:19:05 -0800 Subject: [PATCH 073/215] [CODEGEN] Now padding shared memory for layout conversion (#468) --- include/triton/codegen/analysis/layout.h | 5 +++ lib/codegen/analysis/layout.cc | 8 ++++ lib/codegen/analysis/swizzle.cc | 51 ++++++++++++------------ lib/codegen/selection/generator.cc | 33 ++++++++++----- 4 files changed, 62 insertions(+), 35 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index b6376d7cc..56fb1e4b9 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -103,6 +103,7 @@ public: int shape_per_cta(size_t k) { return shape_per_cta_.at(k); } int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; } + virtual int contig_per_thread(size_t k) = 0; protected: std::vector shape_per_cta_; @@ -181,6 +182,7 @@ public: int wpt(size_t k) { return wpt_.at(k); } int spw(size_t k) { return spw_.at(k); } int rep(size_t k) { return rep_.at(k); } + int contig_per_thread(size_t k) { return contig_per_thread_.at(k); } // helpers for generator.cc std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); } @@ -203,6 +205,8 @@ private: std::vector spt_; // repetitions std::vector rep_; + // contiguous per thread + std::vector contig_per_thread_; TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32; }; @@ -218,6 +222,7 @@ struct scanline_layout: public distributed_layout { // accessor int mts(size_t k) { return mts_.at(k); } int nts(size_t k) { return nts_.at(k); } + int contig_per_thread(size_t k) { return nts_.at(k); } public: // micro tile size. The size of a tile held by a thread block. 
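The scratch tile used for layout conversion is now padded along its fastest-varying dimension by the larger of the producer's and consumer's per-thread contiguity (`contig_per_thread`), and the padded size is used as the tile's leading dimension in `generator.cc` further below; this staggers consecutive rows in shared memory, a common way to keep vectorized accesses from piling onto the same banks. A small numeric sketch of that arithmetic; the concrete sizes below are assumptions for illustration, not values from the patch:

```python
# assumed example: producer (scanline layout) writes 1 contiguous element per
# thread, consumer (mma layout) reads 2 contiguous elements per thread
in_vec, out_vec = 1, 2
pad = max(in_vec, out_vec)        # extra elements appended to the inner dimension
shape = [64, 64]                  # per-CTA tile being converted (assumed)
out_ld = shape[0] + pad           # padded leading dimension of the scratch tile
# i = index along the inner (fast) dimension, j = index along the outer dimension
offset = lambda i, j: i + j * out_ld   # rows are staggered by `pad` elements
```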
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 587234863..fd0a7879b 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -208,10 +208,12 @@ mma_layout::mma_layout(size_t num_warps, int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1; rep_ = {2*pack_size_0, 2*pack_size_1, 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; + contig_per_thread_ = {1, 1}; } else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 + contig_per_thread_ = {1, 1}; // rep_ = {2, 2, 1}; } order_ = {0, 1}; @@ -628,6 +630,12 @@ void layouts::run(ir::module &mod) { shape[k] = std::max(in_layout->shape_per_cta(k), out_layout->shape_per_cta(k)); } + auto in_ord = in_layout->get_order(); + auto out_ord = out_layout->get_order(); + int in_vec = in_layout->contig_per_thread(in_ord[0]); + int out_vec = out_layout->contig_per_thread(out_ord[0]); + int pad = std::max(in_vec, out_vec); + shape[out_ord[0]] += pad; layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_); tmp_[val] = id; } diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index 5737f80a0..414b0e0e5 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -14,41 +14,42 @@ void swizzle::run(ir::module &) { max_phase_.clear(); for(auto &x: layouts_->get_all()){ - shared_layout* layout = dynamic_cast(x.second); - if(!layout) + shared_layout* out_layout = dynamic_cast(x.second); + if(!out_layout) continue; - ir::value* mma_dot_a = layout->hmma_dot_a(); - ir::value* mma_dot_b = layout->hmma_dot_b(); - - if(!mma_dot_a && !mma_dot_b){ - per_phase_[layout] = 1; - max_phase_[layout] = 1; - vec_[layout] = 1; - continue; - } - auto ord = layout->get_order(); - scanline_layout* in_layout = dynamic_cast(layout->get_arg_layout()); + scanline_layout* in_layout = dynamic_cast(out_layout->get_arg_layout()); if(!in_layout) continue; - int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + + ir::value* mma_dot_a = out_layout->hmma_dot_a(); + ir::value* mma_dot_b = out_layout->hmma_dot_b(); + + if(!mma_dot_a && !mma_dot_b){ + per_phase_[out_layout] = 1; + max_phase_[out_layout] = 1; + vec_[out_layout] = 1; + continue; + } + auto ord = out_layout->get_order(); + int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ int inner = mma_dot_a ? 0 : 1; - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout]; + per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[out_layout] = (ord[inner] == 1 ? 
8 : 4) / per_phase_[out_layout]; if(mma_dot_a) - vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); + vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); else - vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); + vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); } else { - if (!layout->allow_swizzle()) { - per_phase_[layout] = 1; - max_phase_[layout] = 1; - vec_[layout] = 1; + if (!out_layout->allow_swizzle()) { + per_phase_[out_layout] = 1; + max_phase_[out_layout] = 1; + vec_[out_layout] = 1; } else { - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; - vec_[layout] = layout->get_mma_vec(); + per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout]; + vec_[out_layout] = out_layout->get_mma_vec(); } } } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index b4f1dd41e..f8cf08cba 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2377,8 +2377,11 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ } in_ord = in_layout->to_mma() ? out_ord : in_ord; out_ord = out_layout->to_mma() ? in_ord : out_ord; - Value *in_ld = i32(shape[in_ord[0]]); - Value *out_ld = i32(shape[out_ord[0]]); + int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]); + int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]); + int pad = std::max(in_vec, out_vec); + Value *in_ld = i32(shape[in_ord[0]] + pad); + Value *out_ld = i32(shape[out_ord[0]] + pad); for(int i = 0; i < n_reps[0]; i++) for(int j = 0; j < n_reps[1]; j++){ int max_ii, max_jj; @@ -2386,29 +2389,39 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ max_ii = in_ax[0].size()/n_reps[0]; max_jj = in_ax[1].size()/n_reps[1]; for(int ii = 0; ii < max_ii; ii++) - for(int jj = 0; jj < max_jj; jj++){ + for(int jj = 0; jj < max_jj; jj+=in_vec){ // shared mem pointer indices_t offs = {in_ax[0][ii], in_ax[1][jj]}; Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); Value *ptr = gep(base, off); // stash value to shared mem - indices_t idxs = {in_ax[0][i*max_ii + ii], - in_ax[1][j*max_jj + jj]}; - store(bit_cast(vals_[in][idxs], ty), ptr); + Value* vals = UndefValue::get(vec_ty(ty, in_vec)); + for(int jjj = 0; jjj < in_vec; jjj++){ + indices_t idxs = {in_ax[0][i*max_ii + ii], + in_ax[1][j*max_jj + jj + jjj]}; + Value* val = bit_cast(vals_[in][idxs], ty); + vals = insert_elt(vals, val, jjj); + } + ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace())); + store(vals, ptr); } add_barrier(); max_ii = out_ax[0].size()/n_reps[0]; max_jj = out_ax[1].size()/n_reps[1]; for(int ii = 0; ii < max_ii; ii++) - for(int jj = 0; jj < max_jj; jj++){ + for(int jj = 0; jj < max_jj; jj+=out_vec){ // shared mem pointer indices_t offs = {out_ax[0][ii], out_ax[1][jj]}; Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); Value *ptr = gep(base, off); + ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace())); // load value from shared rem - indices_t idxs = {out_ax[0][i*max_ii + ii], - out_ax[1][j*max_jj + jj]}; - vals_[out][idxs] = load(ptr); + Value* vals = load(ptr); + for(int jjj = 0; jjj < out_vec; jjj++){ + indices_t idxs = {out_ax[0][i*max_ii + ii], + 
out_ax[1][j*max_jj + jj + jjj]}; + vals_[out][idxs] = extract_elt(vals, jjj); + } } } From a50a47a85b9d55c196b4b046ccd90a2f43d7a2a1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 4 Mar 2022 01:53:31 -0800 Subject: [PATCH 074/215] [CODEGEN] Reverted some changes from previous PR; fixed vectorization characteristics of mma layout (#469) --- lib/codegen/analysis/layout.cc | 2 +- lib/codegen/analysis/swizzle.cc | 47 ++++++++++++++++----------------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index fd0a7879b..5d30a2f45 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -213,7 +213,7 @@ mma_layout::mma_layout(size_t num_warps, else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 - contig_per_thread_ = {1, 1}; + contig_per_thread_ = {1, 2}; // rep_ = {2, 2, 1}; } order_ = {0, 1}; diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index 414b0e0e5..5737f80a0 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -14,42 +14,41 @@ void swizzle::run(ir::module &) { max_phase_.clear(); for(auto &x: layouts_->get_all()){ - shared_layout* out_layout = dynamic_cast(x.second); - if(!out_layout) + shared_layout* layout = dynamic_cast(x.second); + if(!layout) continue; - scanline_layout* in_layout = dynamic_cast(out_layout->get_arg_layout()); - if(!in_layout) - continue; - - ir::value* mma_dot_a = out_layout->hmma_dot_a(); - ir::value* mma_dot_b = out_layout->hmma_dot_b(); + ir::value* mma_dot_a = layout->hmma_dot_a(); + ir::value* mma_dot_b = layout->hmma_dot_b(); if(!mma_dot_a && !mma_dot_b){ - per_phase_[out_layout] = 1; - max_phase_[out_layout] = 1; - vec_[out_layout] = 1; + per_phase_[layout] = 1; + max_phase_[layout] = 1; + vec_[layout] = 1; continue; } - auto ord = out_layout->get_order(); - int dtsize = out_layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + auto ord = layout->get_order(); + scanline_layout* in_layout = dynamic_cast(layout->get_arg_layout()); + if(!in_layout) + continue; + int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ int inner = mma_dot_a ? 0 : 1; - per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[out_layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[out_layout]; + per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[layout] = (ord[inner] == 1 ? 
8 : 4) / per_phase_[layout]; if(mma_dot_a) - vec_[out_layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); + vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); else - vec_[out_layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); + vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); } else { - if (!out_layout->allow_swizzle()) { - per_phase_[out_layout] = 1; - max_phase_[out_layout] = 1; - vec_[out_layout] = 1; + if (!layout->allow_swizzle()) { + per_phase_[layout] = 1; + max_phase_[layout] = 1; + vec_[layout] = 1; } else { - per_phase_[out_layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[out_layout] = out_layout->get_mma_strided() / per_phase_[out_layout]; - vec_[out_layout] = out_layout->get_mma_vec(); + per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; + vec_[layout] = layout->get_mma_vec(); } } } From 21f8a0646dd55222c90b3d86ebf5c28b36897cf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fukan=20Tuna?= Date: Sat, 5 Mar 2022 11:50:37 +0300 Subject: [PATCH 075/215] [DOCS] Minor README.md (#470) Added binary distribution for quick installation --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 25c4aacb8..f4b6ef41c 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,21 @@ The foundations of this project are described in the following MAPL2019 publicat The [official documentation](https://triton-lang.org) contains installation instructions and tutorials. +# Quick Installation + +You can install the latest stable release of Triton from pip: + +```bash +pip install triton +``` +Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7. + +And the latest nightly release: + +```bash +pip install -U --pre triton +``` + # Changelog Version 1.1 is out! 
New features include: From d4d8eaf6c08d824b0e098c6438f3cc2230279477 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 15 Mar 2022 12:20:51 -0700 Subject: [PATCH 076/215] [FRONTEND] improved caching mechanism (#474) Co-authored-by: Greg Brockman Co-authored-by: Christopher Hesse --- python/src/triton.cc | 9 +- python/test/unit/runtime/test_cache.py | 4 +- python/triton/code_gen.py | 177 ++++++++++++++----------- 3 files changed, 106 insertions(+), 84 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index c5c5b196f..9e53cc341 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -299,8 +299,12 @@ void init_triton_runtime(py::module &&m) { // get cached binary py::str key(cache_key); - if(!bin_cache.contains(key)) - add_to_cache(key, args, device, num_warps, num_stages); + py::bool_ noop = false; + if(!bin_cache.contains(key)) { + noop = add_to_cache(key, args, device, num_warps, num_stages); + } + if (noop) + return (py::object)py::none(); py::object bin = bin_cache[key]; // get grid @@ -529,6 +533,7 @@ void init_triton_codegen(py::module &&m) { return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); }, py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ + py::gil_scoped_release allow_threads; if(backend == CUDA) return cu_load_binary(name, asm_map, n_shared_bytes, dev); if(backend == ROCM) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 48797b51a..8ac01bcc8 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -76,7 +76,7 @@ def reset_tmp_dir(): def test_reuse(): counter = 0 - def inc_counter(key, binary, repr): + def inc_counter(*args, **kwargs): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter @@ -91,7 +91,7 @@ def test_reuse(): def test_specialize(mode): counter = 0 - def inc_counter(key, binary, repr): + def inc_counter(*args, **kwargs): nonlocal counter counter += 1 JITFunction.cache_hook = inc_counter diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 894b3f1e3..3f170098b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -602,8 +602,19 @@ class Kernel: return 'str' raise NotImplementedError(f'could not compute type name for {obj}') + @staticmethod + def _to_python_ir(obj): + # convert torch.Tensor to Triton IR pointers + if hasattr(obj, 'data_ptr'): + name = Kernel._type_name(obj) + return 'ptr', name + # default path returns triton.ir.type directly + name = Kernel._type_name(obj) + return 'scalar', name + @staticmethod def _to_triton_ir(context, obj): + which, name = obj type_map = { 'I': _triton.ir.type.get_int32, 'L': _triton.ir.type.get_int64, @@ -625,12 +636,10 @@ class Kernel: 'u64': _triton.ir.type.get_uint64, } # convert torch.Tensor to Triton IR pointers - if hasattr(obj, 'data_ptr'): - name = Kernel._type_name(obj) + if which == 'ptr': elt_ty = type_map[name](context) return _triton.ir.type.make_ptr(elt_ty, 1) # default path returns triton.ir.type directly - name = Kernel._type_name(obj) return type_map[name](context) @staticmethod @@ -648,36 +657,6 @@ class Kernel: def __init__(self, fn): self.fn = fn - def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): - # create IR module - context = _triton.ir.context() - # get just-in-time proto-type of kernel - fn_args = [arg for i, arg in enumerate(wargs) if i not in 
constants] - arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args] - ret_type = _triton.ir.type.get_void(context) - prototype = _triton.ir.type.make_function(ret_type, arg_types) - # generate Triton-IR - # export symbols visible from self.fn into code-generator object - gscope = self.fn.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) - try: - generator.visit(self.fn.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.fn.src, node) from e - # Compile to machine code - if torch.version.hip is None: - backend = _triton.runtime.backend.CUDA - else: - backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - if shared_mem > max_shared_memory: - raise OutOfResources(shared_mem, max_shared_memory, "shared memory") - return Binary(backend, name, asm, shared_mem, num_warps) - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes @@ -692,57 +671,12 @@ class Kernel: range_size = _triton.runtime.get_pointer_range_size(addr) attributes[i] = min(Kernel.pow2_divisor(addr), Kernel.pow2_divisor(range_size)) - # transforms ints whose value is one into constants for just-in-time compilation constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir and not os.path.exists(cache_dir): - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - if binary is None: - binary = self._compile( - *wargs, device=device_idx, attributes=attributes, - num_warps=num_warps, num_stages=num_stages, - constants=constants, - ) - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - if JITFunction.cache_hook is not None: - name = self.fn.__name__ - info = key.split('-')[-3:] - num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] - # make signature human-readable - arg_reprs = [] - for arg_name, arg_sig in zip(self.fn.arg_names, sig): - arg_reprs.append(f'{arg_name}: {arg_sig}') - # assemble the repr - arg_reprs = ", ".join(arg_reprs) - repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" - JITFunction.cache_hook(key=key, binary=binary, repr=repr) - - self.fn.bin_cache[key] = LoadedBinary(device_idx, binary) + arg_types = 
[Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] + return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): # handle arguments passed by name @@ -1027,6 +961,89 @@ class JITFunction: self.kernel = decorator(self.kernel) return self.kernel + def warmup(self, compile): + return self._warmup(**compile, is_manual_warmup=True) + + def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup): + hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + + # create cache directory + cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + if cache_dir: + bin_cache_path = os.path.join(cache_dir, hashed_key) + bin_lock_path = bin_cache_path + ".lock" + else: + bin_cache_path = None + bin_lock_path = None + + binary = None + if bin_cache_path and os.path.exists(bin_cache_path): + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path, 'rb') as f: + binary = pickle.load(f)["binary"] + + compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages) + if JITFunction.cache_hook is not None: + name = self.__name__ + info = key.split('-')[-3:] + num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] + # make signature human-readable + arg_reprs = [] + for arg_name, arg_sig in zip(self.arg_names, sig): + arg_reprs.append(f'{arg_name}: {arg_sig}') + # assemble the repr + arg_reprs = ", ".join(arg_reprs) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None) + if noop: + return True + + if binary is None: + binary = self._compile(**compile) + + if bin_cache_path: + assert bin_lock_path is not None + with FileLock(bin_lock_path): + with open(bin_cache_path + ".tmp", "wb") as f: + pickle.dump({"binary": binary, "key": key}, f) + os.rename(bin_cache_path + ".tmp", bin_cache_path) + + self.bin_cache[key] = LoadedBinary(device, binary) + return False + + def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): + # create IR module + context = _triton.ir.context() + # get just-in-time proto-type of kernel + arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] + ret_type = _triton.ir.type.get_void(context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) + # generate Triton-IR + # export symbols visible from self into code-generator object + gscope = self.__globals__ + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + try: + generator.visit(self.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(self.src, node) from e + # Compile to machine code + if torch.version.hip is None: + backend = _triton.runtime.backend.CUDA + else: + backend = _triton.runtime.backend.ROCM + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) + max_shared_memory = 
_triton.runtime.max_shared_memory(backend, device) + if shared_mem > max_shared_memory: + raise OutOfResources(shared_mem, max_shared_memory, "shared memory") + return Binary(backend, name, asm, shared_mem, num_warps) + def __getitem__(self, grid): return Launcher(self._init_kernel(), grid) From 0dd2ec2e3a8dfb54092577c6418bc083684fca88 Mon Sep 17 00:00:00 2001 From: Yongjik Kim Date: Wed, 16 Mar 2022 14:38:56 -0700 Subject: [PATCH 077/215] [FRONTEND] Add an assert in case we get a CPU tensor. (#478) --- python/triton/code_gen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 3f170098b..09254c967 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -689,6 +689,10 @@ class Kernel: # handle annotations for pos, _type in self.fn.annotations.items(): wargs[pos] = _type(wargs[pos]) + # check that tensors are on GPU. + for arg in wargs: + if hasattr(arg, 'data_ptr'): + assert arg.is_cuda, "All tensors must be on GPU!" # query device index and cuda stream device = torch.cuda.current_device() torch.cuda.set_device(device) From 539961072c277221e4e1dfda81e718daf9bbc1c7 Mon Sep 17 00:00:00 2001 From: daadaada Date: Thu, 17 Mar 2022 12:25:30 +0800 Subject: [PATCH 078/215] [FRONTEND] Semantic analysis refactor (#473) Moved dispatch.cc to semantic.py Integer signedness now moved from C++ to python Cleaner frontend type Co-authored-by: Phil Tillet --- include/triton/ir/builder.h | 31 +- include/triton/ir/context_impl.h | 1 - include/triton/ir/dispatch.h | 113 --- include/triton/ir/module.h | 33 +- include/triton/ir/type.h | 20 +- lib/ir/builder.cc | 53 +- lib/ir/context.cc | 18 +- lib/ir/dispatch.cc | 882 ----------------- lib/ir/instructions.cc | 4 +- lib/ir/module.cc | 140 --- lib/ir/type.cc | 14 - python/src/triton.cc | 392 +++++--- python/test/regression/test_performance.py | 2 +- python/test/unit/language/test_core.py | 23 - python/test/unit/runtime/test_cache.py | 28 + python/triton/__init__.py | 3 +- python/triton/code_gen.py | 314 ++++-- python/triton/language/core.py | 606 ++++++++---- python/triton/language/semantic.py | 1037 ++++++++++++++++++++ 19 files changed, 2044 insertions(+), 1670 deletions(-) delete mode 100644 include/triton/ir/dispatch.h delete mode 100644 lib/ir/dispatch.cc create mode 100644 python/triton/language/semantic.py diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 2b6bc6ab3..fe85be947 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -38,10 +38,8 @@ public: iterator get_insert_point() { return insert_point_;} // Constants value *get_int1(bool val); - value *get_int32(int32_t val); - value *get_int64(int64_t val); - value *get_uint32(uint32_t val); - value *get_uint64(uint64_t val); + value *get_int32(uint32_t val); + value *get_int64(uint64_t val); value *get_float16(float val); value *get_float32(float val); value *get_range(int32_t lo, int32_t hi); @@ -52,11 +50,9 @@ public: type *get_int16_ty(); type *get_int32_ty(); type *get_int64_ty(); - type *get_uint8_ty(); - type *get_uint16_ty(); - type *get_uint32_ty(); - type *get_uint64_ty(); + type *get_fp8_ty(); type *get_half_ty(); + type *get_bf16_ty(); type *get_float_ty(); type *get_double_ty(); // Insert @@ -74,7 +70,9 @@ public: value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); // Cast instructions + value* create_bitcast(value *src, type *dest_ty); value *create_cast(cast_op_t op, value *v, type *dst_ty); + value* 
create_int_to_ptr(value *src, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty); value* create_si_to_fp(value *src, type *dst_ty); value* create_ui_to_fp(value *src, type *dst_ty); @@ -93,11 +91,11 @@ public: value *create_frem(value *lhs, value *rhs); value *create_fadd(value *lhs, value *rhs); value *create_fsub(value *lhs, value *rhs); - value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sdiv(value *lhs, value *rhs); value *create_udiv(value *lhs, value *rhs); value *create_srem(value *lhs, value *rhs); value *create_urem(value *lhs, value *rhs); + value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); @@ -145,11 +143,22 @@ public: value *create_reshape(value *arg, const type::block_shapes_t &shapes); value *create_cat(value *lhs, value *rhs); value *create_broadcast(value *arg, const type::block_shapes_t &shapes); + // Atomic instruction + value *create_atomic_cas(value *ptr, value *cmp, value *val); + value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk); + value *create_atomic_max(value *ptr, value *val, value *msk); + value *create_atomic_umax(value *ptr, value *val, value *msk); + value *create_atomic_min(value *ptr, value *val, value *msk); + value *create_atomic_umin(value *ptr, value *val, value *msk); + value *create_atomic_fadd(value *ptr, value *val, value *msk); + value *create_atomic_add(value *ptr, value *val, value *msk); + value *create_atomic_and(value *ptr, value *val, value *msk); + value *create_atomic_or(value *ptr, value *val, value *msk); + value *create_atomic_xor(value *ptr, value *val, value *msk); + value *create_atomic_xchg(value *ptr, value *val, value *msk); // Built-in instruction value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); - value *create_atomic_cas(value *ptr, value *cmp, value *val); - value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk); value *create_exp(value* arg); value *create_cos(value* arg); value *create_sin(value* arg); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 081ea249d..ef20af6b7 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -26,7 +26,6 @@ public: type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; // integer types integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; - integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types std::map, std::unique_ptr> ptr_tys; // Block types diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h deleted file mode 100644 index ef14043dd..000000000 --- a/include/triton/ir/dispatch.h +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_DISPATCH_H_ -#define _TRITON_IR_DISPATCH_H_ - -#include "triton/ir/builder.h" -#include - -namespace triton{ -namespace ir{ - - -/*---------------------------------------------- - higher level functions that follow the likely - semantics of most expected frontends - ----------------------------------------------*/ - -struct semantic_error: public std::runtime_error { - semantic_error(const std::string& msg): - std::runtime_error(msg) { } -}; - -struct dispatch{ - typedef 
ir::type::block_shapes_t shape_t; - - - // programming model - static ir::value *program_id(int axis, ir::builder *builder); - static ir::value *num_programs(int axis, ir::builder *builder); - - // binary operators - static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder); - static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder); - - // unary operators - static ir::value *plus(ir::value *input, ir::builder *builder); - static ir::value *minus(ir::value *input, ir::builder *builder); - static ir::value *invert(ir::value *input, ir::builder *builder); - - // comparison operators - static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder); - - // block creation - static ir::value* arange(int start, int end, ir::builder *builder); - static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder); - - - // casting ops - static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder); - static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder); - static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder); - static std::tuple broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder); - static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder); - static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); - - // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, - const std::string& eviction_policy, int is_volatile, ir::builder *builder); - static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); - static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); - static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder 
*builder); - static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - - // linear algebra - static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder); - - // indexing - static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); - - // reduction - static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder); - - // math - static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); - static ir::value *exp(ir::value *x, ir::builder *builder); - static ir::value *log(ir::value *x, ir::builder *builder); - static ir::value *cos(ir::value *x, ir::builder *builder); - static ir::value *sin(ir::value *x, ir::builder *builder); - static ir::value *sqrt(ir::value *x, ir::builder *builder); - - // internal (debug/optimization) - static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); - static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); - static ir::value *debug_barrier(ir::builder *builder); -}; - -} -} - -#endif diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 30881fd49..ea64dfc6e 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -57,26 +57,10 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - module(const std::string &name, builder& builder); - builder& get_builder(); - // Setters - void set_value(const std::string& name, basic_block* block, value *x); - void set_value(const std::string& name, value* x); - void set_const(const std::string& name); - void set_continue_fn(std::function fn); - // Getters - const std::map& get_values() { return values_; } - const std::map& get_types() { return types_; } - void set_values(const std::map& values) { values_ = values; } - void set_types(const std::map& types) { types_ = types; } + module(const std::string &name, builder &builder): name_(name), builder_(builder) {} + builder &get_builder() { return builder_; }; + const std::string& get_name() { return name_; }; - value *get_value(const std::string& name, basic_block* block); - value *get_value(const std::string& name); - void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } - const std::string& get_name(); - std::function get_continue_fn(); - // Seal block -- no more predecessors will be added - void seal_block(basic_block *block); // Functions const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } @@ -89,21 +73,14 @@ public: const std::map& globals() const { return globals_; } // Metadata void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } - + const std::map &get_metadatas() const { return metadatas_; } void print(std::ostream &os); private: std::string name_; - builder& builder_; - std::map values_; - std::map types_; - std::set const_; - std::set sealed_blocks_; - std::map> incomplete_phis_; + builder &builder_; functions_list_t 
functions_; symbols_map_t symbols_; - std::function continue_fn_; - std::map current_phi_; std::vector allocs_; std::map globals_; std::map metadatas_; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 47c9b5f85..b1ef1ad22 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -16,8 +16,6 @@ class value; class integer_type; class constant_int; -enum class signedness { SIGNED, UNSIGNED }; - /* Type */ class type { public: @@ -61,8 +59,6 @@ public: // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bitwidth() const; - signedness get_integer_signedness() const; - bool is_integer_signed() const; unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; @@ -85,9 +81,6 @@ public: bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; } - bool is_integer_ty(unsigned bitwidth, signedness sn) { - return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn; - } bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } @@ -115,10 +108,6 @@ public: static integer_type *get_int32_ty(context &ctx); static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); - static integer_type *get_uint8_ty(context &ctx); - static integer_type *get_uint16_ty(context &ctx); - static integer_type *get_uint32_ty(context &ctx); - static integer_type *get_uint64_ty(context &ctx); // repr std::string tile_repr() const { @@ -145,7 +134,7 @@ public: case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; - case IntegerTyID: return (is_integer_signed() ? 
"i" : "u") + std::to_string(get_integer_bitwidth()); + case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth()); case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; @@ -168,21 +157,18 @@ class integer_type: public type { private: // constructors - integer_type(context &ctx, unsigned bitwidth, signedness sn) - : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ } + integer_type(context &ctx, unsigned bitwidth) + : type(ctx, IntegerTyID), bitwidth_(bitwidth) {} public: // accessors unsigned get_bitwidth() const { return bitwidth_; } - signedness get_signedness() const { return signedness_; } - // factory methods static integer_type* get(context &ctx, unsigned width); private: unsigned bitwidth_; - signedness signedness_; }; class composite_type: public type{ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index fff73e665..9b8a2a45e 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -48,18 +48,12 @@ void builder::set_insert_point(basic_block *block){ value *builder::get_int1(bool val) { return constant_int::get(type::get_int1_ty(ctx_), val); } -value *builder::get_int32(int32_t val) +value *builder::get_int32(uint32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} -value *builder::get_uint32(uint32_t val) -{ return constant_int::get(type::get_uint32_ty(ctx_), val);} - -value *builder::get_int64(int64_t val) +value *builder::get_int64(uint64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} -value *builder::get_uint64(uint64_t val) -{ return constant_int::get(type::get_uint64_ty(ctx_), val);} - value *builder::get_float16(float val) { return constant_fp::get(type::get_fp16_ty(ctx_), val); } @@ -90,21 +84,15 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } -type *builder::get_uint8_ty() -{ return type::get_uint8_ty(ctx_); } - -type *builder::get_uint16_ty() -{ return type::get_uint16_ty(ctx_); } - -type *builder::get_uint32_ty() -{ return type::get_uint32_ty(ctx_); } - -type *builder::get_uint64_ty() -{ return type::get_uint64_ty(ctx_); } +type *builder::get_fp8_ty() +{ return type::get_fp8_ty(ctx_); } type *builder::get_half_ty() { return type::get_fp16_ty(ctx_); } +type *builder::get_bf16_ty() +{ return type::get_bf16_ty(ctx_); } + type *builder::get_float_ty() { return type::get_fp32_ty(ctx_); } @@ -139,6 +127,8 @@ value *builder::create_ret_void() { return create_cast(OPCODE, src, dst_ty);\ } +DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast) +DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr) DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) @@ -331,6 +321,28 @@ value *builder::create_downcast(value *arg) { return insert(downcast_inst::create(arg)); } +// + +value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ + return insert(atomic_rmw_inst::create(op, ptr, val, msk)); +} + +#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\ + value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\ + return create_atomic_rmw(OPCODE, ptr, val, mask);\ + } + +DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max) +DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax) +DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min) +DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin) +DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, 
ir::atomic_rmw_op_t::FAdd) +DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) +DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And) +DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or) +DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor) +DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg) + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// @@ -347,9 +359,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ return insert(atomic_cas_inst::create(ptr, cmp, val)); } -value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ - return insert(atomic_rmw_inst::create(op, ptr, val, msk)); -} value *builder::create_exp(value *arg){ return insert(exp_inst::create(arg)); diff --git a/lib/ir/context.cc b/lib/ir/context.cc index 90b109b9b..0fc65ddc2 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -19,18 +19,12 @@ context_impl::context_impl(context &ctx) fp32_ty(ctx, type::FP32TyID), fp64_ty(ctx, type::FP64TyID), // integers - int1_ty(ctx, 1, signedness::SIGNED), - int8_ty(ctx, 8, signedness::SIGNED), - int16_ty(ctx, 16, signedness::SIGNED), - int32_ty(ctx, 32, signedness::SIGNED), - int64_ty(ctx, 64, signedness::SIGNED), - int128_ty(ctx, 128, signedness::SIGNED), - uint8_ty(ctx, 8, signedness::UNSIGNED), - uint16_ty(ctx, 16, signedness::UNSIGNED), - uint32_ty(ctx, 32, signedness::UNSIGNED), - uint64_ty(ctx, 64, signedness::UNSIGNED){ - -} + int1_ty(ctx, 1), + int8_ty(ctx, 8), + int16_ty(ctx, 16), + int32_ty(ctx, 32), + int64_ty(ctx, 64), + int128_ty(ctx, 128) {} //===----------------------------------------------------------------------===// // context diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc deleted file mode 100644 index 664fbb983..000000000 --- a/lib/ir/dispatch.cc +++ /dev/null @@ -1,882 +0,0 @@ -#include "triton/ir/dispatch.h" - -namespace triton { -namespace ir { - - -[[ noreturn ]] void throw_unreachable(std::string key) { - throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. " - "This is likely a bug on our side."); -} - -//===----------------------------------------------------------------------===// -// Programming Model -//===----------------------------------------------------------------------===// - -ir::value *dispatch::program_id(int axis, ir::builder *builder) { - return builder->create_get_program_id(axis); -} - -ir::value *dispatch::num_programs(int axis, ir::builder *builder) { - return builder->create_get_num_programs(axis); -} - -//===----------------------------------------------------------------------===// -// Implicit Casting Utilities -//===----------------------------------------------------------------------===// - -ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ - int a_rank = a_ty->get_integer_bitwidth(); - int b_rank = b_ty->get_integer_bitwidth(); - auto a_sn = a_ty->get_integer_signedness(); - auto b_sn = b_ty->get_integer_signedness(); - // Rules for signedness taken from "Usual arithmetic conversions" on - // https://en.cppreference.com/w/c/language/conversion. - if (a_sn == b_sn) { - return a_rank > b_rank ? a_ty : b_ty; - } else if (a_sn == signedness::UNSIGNED) { - return a_rank >= b_rank ? a_ty : b_ty; - } else if (b_sn == signedness::UNSIGNED) { - return b_rank >= a_rank ? 
b_ty : a_ty; - } else { - throw_unreachable("integer_promote"); - } -} - -enum class DivOrMod { NO, YES }; - -ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) { - context &ctx = a_ty->get_context(); - // 1) if one operand is double, the other is implicitly - // converted to double - if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) - return type::get_fp64_ty(ctx); - // 2) if one operand is float, the other is implicitly - // converted to float - if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) - return type::get_fp32_ty(ctx); - // 3 ) if one operand is half, the other is implicitly converted to half - // unless we're doing / or %, which do not exist natively in PTX for fp16. - if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) { - if (div_or_mod == DivOrMod::YES) { - return type::get_fp32_ty(ctx); - } else { - return type::get_fp16_ty(ctx); - } - } - if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) - throw_unreachable("computation_type"); - // 4 ) both operands are integer and undergo - // integer promotion - if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) { - throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); - } - return integer_promote(a_ty, b_ty); -} - -//===----------------------------------------------------------------------===// -// Binary Operators -//===----------------------------------------------------------------------===// - -void throw_incompatible_types(ir::type* type_a, ir::type* type_b) { - throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr()); -} - -void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){ - - if(type_a->is_pointer_ty()){ - if(!allow_ptr_a) - throw_incompatible_types(type_a, type_b); - // T* + U* with T != U - if(type_b->is_pointer_ty() && (type_a != type_b)) - throw_incompatible_types(type_a, type_b); - // T* + float - if(type_b->is_floating_point_ty()) - throw_incompatible_types(type_a, type_b); - } -} - -void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, - bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, - bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) { - // implicit broadcasting - std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); - // implicit typecasting - ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty(); - ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); - check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); - check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); - if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) { - ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod); - lhs = dispatch::cast(lhs, ret_sca_ty, builder); - rhs = dispatch::cast(rhs, ret_sca_ty, builder); - } -} - -ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, true, true); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // offset + ptr - // ptr + offset - if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty()) - std::swap(input, other); - if (input_scalar_ty->is_pointer_ty()) - return builder->create_gep(input, {other}); - // float + float - else if 
(input_scalar_ty->is_floating_point_ty()) - return builder->create_fadd(input, other); - // int + int - else if (input_scalar_ty->is_integer_ty()) - return builder->create_add(input, other); - throw_unreachable("add"); -} - -ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, true, false); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // ptr - offset - if (scalar_ty->is_pointer_ty()) - return builder->create_gep(input, {dispatch::minus(other, builder)}); - // float + float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fsub(input, other); - // int + int - else if (scalar_ty->is_integer_ty()) - return builder->create_sub(input, other); - throw_unreachable("sub"); -} - -ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float * float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fmul(input, other); - // int * int - else if (scalar_ty->is_integer_ty()) - return builder->create_mul(input, other); - throw_unreachable("mul"); -} - -ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // float / int - if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty()) - other = cast(other, input_scalar_ty, builder); - // int / float - else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty()) - input = cast(input, other_scalar_ty, builder); - // int / int (cast to float32) - else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ - input = cast(input, builder->get_float_ty(), builder); - other = cast(other, builder->get_float_ty(), builder); - } - // float / float (cast to highest exponent type) - else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){ - if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width()) - other = cast(other, input_scalar_ty, builder); - else - input = cast(input, other_scalar_ty, builder); - } - // unreachable - else - throw_unreachable("div"); - return builder->create_fdiv(input, other); -} - -ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ - ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); - input = dispatch::cast(input, ret_ty, builder); - other = dispatch::cast(other, ret_ty, builder); - if (ret_ty->is_integer_signed()) { - return builder->create_sdiv(input, other); - } else { - return builder->create_udiv(input, other); - } - } - throw_unreachable("floordiv"); -} - -ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){ - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - if(!input_scalar_ty->is_floating_point_ty() || 
!other_scalar_ty->is_floating_point_ty()) - throw semantic_error("both operands of fdiv must have floating point scalar type"); - binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES); - ir::value* ret = builder->create_fdiv(input, other); - if(ir::binary_operator* binop = dynamic_cast(ret)) - binop->set_fdiv_ieee_rounding(ieee_rounding->get_value()); - return ret; -} - -ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // float % int - if (scalar_ty->is_floating_point_ty()) - return builder->create_frem(input, other); - // int % int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) { - throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); - } - if (scalar_ty->is_integer_signed()) { - return builder->create_srem(input, other); - } else { - return builder->create_urem(input, other); - } - } - throw_unreachable("mod"); -} - - -void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, false); - ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); - ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); - if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) - throw_incompatible_types(input_sca_ty, other_sca_ty); - ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty); - if (ret_sca_ty != input_sca_ty) - input = dispatch::cast(input, ret_sca_ty, builder); - if (ret_sca_ty != other_sca_ty) - other = dispatch::cast(other, ret_sca_ty, builder); -} - -ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_and(input, other); -} - -ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_or(input, other); -} - - -ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_xor(input, other); -} - - -ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_lshr(input, other); -} - - -ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_shl(input, other); -} - -//===----------------------------------------------------------------------===// -// Unary Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::plus(ir::value *input, ir::builder *) { - return input; -} - -ir::value *dispatch::minus(ir::value *input, ir::builder *builder) { - ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); - if(input_sca_ty->is_pointer_ty()) - throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")"); - ir::value *_0 = 
ir::constant::get_null_value(input_sca_ty); - return dispatch::sub(_0, input, builder); -} - -ir::value *dispatch::invert(ir::value *input, ir::builder *builder) { - ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); - if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty()) - throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")"); - ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty); - return dispatch::xor_(input, _1, builder); -} - - -//===----------------------------------------------------------------------===// -// Comparison Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float > float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGT(input, other); - // int > int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSGT(input, other); - } else { - return builder->create_icmpUGT(input, other); - } - } - throw_unreachable("greater_than"); -} - -ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float >= float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGE(input, other); - // int >= int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSGE(input, other); - } else { - return builder->create_icmpUGE(input, other); - } - } - throw_unreachable("greater_equal"); -} - -ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLT(input, other); - // int < int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSLT(input, other); - } else { - return builder->create_icmpULT(input, other); - } - } - throw_unreachable("less_than"); -} - -ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLE(input, other); - // int < int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSLE(input, other); - } else { - return builder->create_icmpULE(input, other); - } - } - throw_unreachable("less_equal"); -} - -ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOEQ(input, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpEQ(input, other); - throw_unreachable("equal"); -} - -ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - 
ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpUNE(input, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpNE(input, other); - throw_unreachable("equal"); -} - -//===----------------------------------------------------------------------===// -// Block Creation -//===----------------------------------------------------------------------===// - -ir::value* dispatch::arange(int start, int end, ir::builder *builder) { - return builder->get_range(start, end); -} - -ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) { - ir::value *_0 = ir::constant::get_null_value(dtype); - return builder->create_splat(_0, shape); -} - -//===----------------------------------------------------------------------===// -// Shape Manipulation -//===----------------------------------------------------------------------===// - - -ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) { - unsigned numel = 1; - for(unsigned s: dst_shape) numel *= s; - if(input->get_type()->get_tile_num_elements() != numel) - throw semantic_error("cannot reshape block of different shape"); - return builder->create_reshape(input, dst_shape); -} - -ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) { - return builder->create_cat(lhs, rhs); -} - -ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) { - if (!input->get_type()->is_block_ty()) - return builder->create_splat(input, shape); - auto src_shape = input->get_type()->get_block_shapes(); - if (src_shape.size() != shape.size()) - throw std::runtime_error("Cannot broadcast"); - if(shape == src_shape) - return input; - return builder->create_broadcast(input, shape); -} - -std::tuple dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { - ir::type *lhs_ty = lhs->get_type(); - ir::type *rhs_ty = rhs->get_type(); - - // make_shape_compatible(block, scalar) - if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) - rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); - // make_shape_compatible(scalar, block) - else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) - lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); - // make_shape_compatible(block, block) - else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { - auto lhs_shape = lhs_ty->get_block_shapes(); - auto rhs_shape = rhs_ty->get_block_shapes(); - if (lhs_shape.size() != rhs_shape.size()) - throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); - ir::type::block_shapes_t ret_shape; - for (size_t i = 0; i < lhs_shape.size(); ++i) { - unsigned left = lhs_shape[i]; - unsigned right = rhs_shape[i]; - if (left == 1) - ret_shape.push_back(right); - else if (right == 1) - ret_shape.push_back(left); - else if (left == right) - ret_shape.push_back(left); - else - throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + - ": " + std::to_string(left) + " and " + std::to_string(right)); - } - if (lhs_shape != ret_shape) - lhs = builder->create_broadcast(lhs, ret_shape); - if (rhs_shape != ret_shape) - rhs = builder->create_broadcast(rhs, ret_shape); - } - return std::make_tuple(lhs, rhs); -} - -ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){ - ir::type *src_ty = input->get_type(); - if 
(src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - if(src_ty == dst_ty) - return input; - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty()) - return cast(input, dst_ty, builder); - // Bitcast - int src_bits = src_sca_ty->get_primitive_size_in_bits(); - int dst_bits = dst_sca_ty->get_primitive_size_in_bits(); - if( src_bits!= dst_bits) - throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) + - "to data-type of size " + std::to_string(dst_bits)); - return builder->create_cast(ir::BitCast, input, dst_ty); -} - -ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) { - ir::type *src_ty = input->get_type(); - if (src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - if(src_ty == dst_ty) - return input; - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - // - if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) || - (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){ - return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder); - } - // FP Truncation - bool truncate_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); - if (truncate_fp) - return builder->create_fp_trunc(input, dst_ty); - // FP Extension - bool ext_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); - if (ext_fp) - return builder->create_fp_ext(input, dst_ty); - // Int cast - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && - (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() || - src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) { - bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty(); - return builder->create_int_cast(input, dst_ty, sign_extend); - } - // Float -> Int - if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ - if(dst_sca_ty->is_bool_ty()) - return builder->create_fp_to_ui(input, dst_ty); - else - return builder->create_fp_to_si(input, dst_ty); - } - // int -> Float - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ - if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed()) - return builder->create_ui_to_fp(input, dst_ty); - else - return builder->create_si_to_fp(input, dst_ty); - } - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){ - int bitwidth = dst_sca_ty->get_integer_bitwidth(); - if(bitwidth == 64) - return builder->create_cast(ir::PtrToInt, input, dst_ty); - if(bitwidth == 1) - return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder), - builder->get_int64(0), - builder); - } - if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::IntToPtr, input, dst_ty); - // Ptr -> Ptr - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::BitCast, input, dst_ty); - // * -> Bool - if (dst_sca_ty->is_bool_ty()) { - if (src_sca_ty->is_pointer_ty()) - input = cast(input, builder->get_int64_ty(), builder); - ir::value *other = builder->get_int64(0); - if 
(src_ty->is_bool_ty()) - other = builder->create_splat(other, src_ty->get_block_shapes()); - return builder->create_icmpNE(input, other); - } - throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); -} - -//===----------------------------------------------------------------------===// -// Memory Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) { - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()){ - if(mask) - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - if(other) - other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); - } - if(other) - other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); - ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); - // treat bool* as int8* - if(elt_ty == builder->get_int1_ty()){ - elt_ty = builder->get_int8_ty(); - ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); - ptr = dispatch::cast(ptr, ptr_ty, builder); - } - // cache modifier - load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default - if (!cache_modifier.empty()) { - if (cache_modifier == ".ca") - cache = load_inst::CA; - else if (cache_modifier == ".cg") - cache = load_inst::CG; - else - throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); - } - // eviction policy - load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default - if(!eviction_policy.empty()){ - if (eviction_policy == "evict_last") - eviction = load_inst::EVICT_LAST; - else if(eviction_policy == "evict_first") - eviction = load_inst::EVICT_FIRST; - else - throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported"); - } - - - if (!mask && !other) - return builder->create_load(ptr, cache, eviction, is_volatile); - if (!mask) - throw std::runtime_error("`other` cannot be provided without `mask`"); - auto shape = ptr->get_type()->get_block_shapes(); - if(!other){ - other = ir::undef_value::get(elt_ty); - if(ptr->get_type()->is_block_ty()) - other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); - } - return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile); -} - -ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()) - val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); - if(mask) - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); - ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); - // treat bool* as int8* - if(elt_ty == builder->get_int1_ty()){ - elt_ty = builder->get_int8_ty(); - ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); - ptr = dispatch::cast(ptr, ptr_ty, builder); - } - // cast to target data-type - val = dispatch::cast(val, elt_ty, builder); - if (!mask) 
- return builder->create_store(ptr, val); - if(!mask->get_type()->get_scalar_ty()->is_bool_ty()) - throw semantic_error("Mask must have boolean scalar type"); - return builder->create_masked_store(ptr, val, mask); -} - -ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){ - return builder->create_atomic_cas(ptr, cmp, val); -} - -void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){ - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()){ - if(mask){ - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - } - if(val){ - val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); - } - } - val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - if(!mask){ - mask = builder->get_int1(true); - if(ptr->get_type()->is_block_ty()) - mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes()); - } -} - -ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) { - if (sca_ty->is_integer_signed()) { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); - } else { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask); - } - } - // for float - // return atomic_smax(i_ptr, i_val) if val >= 0 - // return atomic_umin(i_ptr, i_val) if val < 0 - ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); - ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); - ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); - ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); - ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder)); - ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder)); - return where(pos, pos_ret, neg_ret, builder); -} - -ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_min for integers - if(sca_ty->is_integer_ty()) { - if (sca_ty->is_integer_signed()) { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); - } else { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask); - } - } - // for float - // return atomic_smin(i_ptr, i_val) if val >= 0 - // return atomic_umax(i_ptr, i_val) if val < 0 - ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); - ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); - ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); - ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); - ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder)); - ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder)); - return where(pos, pos_ret, neg_ret, 
builder); -} - -ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add; - return builder->create_atomic_rmw(op, ptr, val, mask); -} - -ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask); -} - -ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask); -} - -ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask); -} - -ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask); -} - -//===----------------------------------------------------------------------===// -// Linear Algebra -//===----------------------------------------------------------------------===// - -ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) { - ir::value *_0 = nullptr; - if (lhs->get_type()->is_int_or_tileint_ty()) - _0 = builder->get_int32(0); - else - _0 = builder->get_float32(0); - unsigned M = lhs->get_type()->get_block_shapes()[0]; - unsigned N = rhs->get_type()->get_block_shapes()[1]; - _0 = builder->create_splat(_0, {M, N}); - bool _allow_tf32 = allow_tf32->get_value() != 0; - return builder->create_dot(lhs, rhs, _0, _allow_tf32); -} - - -//===----------------------------------------------------------------------===// -// Indexing -//===----------------------------------------------------------------------===// - -ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){ - condition = dispatch::cast(condition, builder->get_int1_ty(), builder); - if(condition->get_type()->is_block_ty()){ - x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); - y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); - } - ir::type* x_ty = x->get_type()->get_scalar_ty(); - ir::type* y_ty = y->get_type()->get_scalar_ty(); - ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO); - x = dispatch::cast(x, ty, builder); - y = dispatch::cast(y, ty, builder); - return builder->create_select(condition, x, y); -} - - -//===----------------------------------------------------------------------===// -// Reductions -//===----------------------------------------------------------------------===// - -ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, - ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // input is extended to 32-bits if necessary - // this increases numerical accuracy and can be done pretty much for free - // on GPUs - if(scalar_ty->is_integer_ty() && 
scalar_ty->get_integer_bitwidth() <= 32) - input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder); - if (scalar_ty->is_floating_point_ty()) - return builder->create_reduce(input, FLOAT_OP, axis); - else if (scalar_ty->is_integer_ty()) - return builder->create_reduce(input, INT_OP, axis); - throw_unreachable(name); -} - -ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); -} - -ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); -} - -ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); -} - -ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - if (!scalar_ty->is_integer_ty()) - throw semantic_error("xor_sum only supported for integers"); - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR); -} - - -//===----------------------------------------------------------------------===// -// Math -//===----------------------------------------------------------------------===// - -ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) { - binary_op_type_checking(x, y, builder); - return builder->insert(umulhi_inst::create(x, y)); -} - -ir::value *dispatch::exp(ir::value *x, ir::builder *builder) { - return builder->create_exp(x); -} - -ir::value *dispatch::log(ir::value *x, ir::builder *builder) { - return builder->create_log(x); -} - -ir::value *dispatch::cos(ir::value *x, ir::builder *builder) { - return builder->create_cos(x); -} - -ir::value *dispatch::sin(ir::value *x, ir::builder *builder) { - return builder->create_sin(x); -} - -ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { - return builder->create_sqrt(x); -} - - -// - -ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ - ir::instruction* i = dynamic_cast(x); - if(!i) - throw_unreachable("multiple_of"); - i->set_metadata(ir::metadata::multiple_of, value); - return i; -} - -ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){ - ir::instruction* i = dynamic_cast(x); - if(!i) - throw_unreachable("max_contiguous"); - i->set_metadata(ir::metadata::max_contiguous, value); - return i; -} - -ir::value *dispatch::debug_barrier(ir::builder *builder) { - return builder->create_barrier(); -} - - -} -} diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index c225b315f..39bd945bc 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -312,8 +312,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast : - (arg_bits > dst_bits ? cast_op_t::Trunc : - (is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); + (arg_bits > dst_bits ? cast_op_t::Trunc : + (is_signed ? 
cast_op_t::SExt : cast_op_t::ZExt))); return create(op, arg, ty, name, next); } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 33b39de3a..a37d3048f 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -9,146 +9,6 @@ namespace triton{ namespace ir{ -/* Module */ -module::module(const std::string &name, builder &builder) - : name_(name), builder_(builder) { - sealed_blocks_.insert(nullptr); -} - -ir::builder& module::get_builder() { - return builder_; -} - -void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ - values_[val_key_t{name, block}] = value; - auto it = metadatas_.find(name); - if(auto *x = dynamic_cast(value)) - if(it != metadatas_.end()){ - x->set_metadata(it->second.first, it->second.second); - } -// value->set_name(name); -} - -void module::set_value(const std::string& name, ir::value *value){ - return set_value(name, builder_.get_insert_block(), value); -} - -void module::set_const(const std::string& name){ - const_.insert(name); -} - -void module::set_continue_fn(std::function fn) { - continue_fn_ = fn; -} - -std::function module::get_continue_fn() { - return continue_fn_; -} - -ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ - basic_block::iterator insert = block->get_first_non_phi(); - if(insert != block->end()){ - builder_.set_insert_point(insert); - } - ir::phi_node *res = builder_.create_phi(ty, num_values); - if(insert != block->end()) - builder_.set_insert_point(block); - return res; -} - -ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ - // find non-self references - std::set non_self_ref; - std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), - [phi](ir::value* op){ return op != phi && op; }); - // non-trivial - if(non_self_ref.size() != 1) - return phi; - // unique value or self-reference - ir::value *same = *non_self_ref.begin(); - assert(same != nullptr); - phi->replace_all_uses_with(same); - phi->erase_from_parent(); - std::set users = phi->get_users(); - for(ir::user* u: users) - if(auto *uphi = dynamic_cast(u)) - if(uphi != phi) - try_remove_trivial_phis(uphi); - return same; -} - - -ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){ - // already initialized - if(phi->get_num_operands()) - return phi; - ir::basic_block *block = phi->get_parent(); - for(ir::basic_block *pred: block->get_predecessors()){ - ir::value *value = get_value(name, pred); - phi->add_incoming(value, pred); - } - return phi; -} - -ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) { - ir::value *result; - bool is_const = const_.find(name) != const_.end(); - auto &preds = block->get_predecessors(); - ir::type *ty = types_.at(name); - if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ - incomplete_phis_[block][name] = make_phi(ty, 1, block); - result = (ir::value*)incomplete_phis_[block][name]; - } - else if(preds.size() <= 1){ - bool has_pred = preds.size(); - result = get_value(name, has_pred?preds.front():nullptr); - } - else{ - ir::phi_node* phi = make_phi(ty, 1, block); - set_value(name, block, phi); - result = add_phi_operands(name, phi); - if(auto *phi = dynamic_cast(result)) - result = try_remove_trivial_phis(phi); - } - if(auto *phi = dynamic_cast(result)){ - result = try_remove_trivial_phis(phi); - } - set_value(name, block, result); - return result; -} - -ir::value *module::get_value(const std::string& name, ir::basic_block *block) { 
- ir::basic_block* save_block = builder_.get_insert_block(); - ir::basic_block::iterator save_pt = builder_.get_insert_point(); - val_key_t key(name, block); - if(values_.find(key) != values_.end()){ - return values_.at(key); - } - ir::value *result = get_value_recursive(name, block); - builder_.set_insert_point(save_block); - if(save_pt != save_block->end()) - builder_.set_insert_point(save_pt); - return result; -} - -ir::value *module::get_value(const std::string& name) { - return get_value(name, builder_.get_insert_block()); -} - -const std::string& module::get_name() { - return name_; -} - -void module::seal_block(ir::basic_block *block){ - for(auto &x: incomplete_phis_[block]){ - add_phi_operands(x.first, x.second); - if(get_value(x.first) == x.second) - set_value(x.first, try_remove_trivial_phis(x.second)); - } - sealed_blocks_.insert(block); - incomplete_phis_[block].clear(); -} - /* functions */ function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *&fn = (function*&)symbols_[name]; diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 7e4e4e5d7..056ae99e6 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -36,16 +36,6 @@ unsigned type::get_primitive_size_in_bits() const { unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } -signedness type::get_integer_signedness() const -{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); } - -bool type::is_integer_signed() const { - if (id_ != IntegerTyID) { - throw std::logic_error("type is " + repr() + ", not integer"); - } - return ((integer_type*)(this))->get_signedness() == signedness::SIGNED; -} - unsigned type::get_tile_bitwidth() const { return ((block_type*)(this))->get_bitwidth(); } @@ -145,10 +135,6 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } -integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; } -integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; } -integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; } -integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index 9e53cc341..b66761ec3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -3,7 +3,6 @@ #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" -#include "triton/ir/dispatch.h" #include "triton/ir/enums.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -12,10 +11,12 @@ #include #include #include +#include #include #include "Python.h" #include #include +#include #include #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" @@ -541,84 +542,6 @@ void init_triton_codegen(py::module &&m) { }, py::return_value_policy::take_ownership); } -/*****************************************************************************/ -/* User-facing language features */ -/*****************************************************************************/ - -void init_triton_frontend(py::module &&m) { - using ret = py::return_value_policy; - - // programming model - m.def("program_id", &ir::dispatch::program_id, ret::reference); - 
m.def("num_programs", &ir::dispatch::num_programs, ret::reference); - // binary - m.def("add", &ir::dispatch::add, ret::reference); - m.def("sub", &ir::dispatch::sub, ret::reference); - m.def("mul", &ir::dispatch::mul, ret::reference); - m.def("truediv", &ir::dispatch::truediv, ret::reference); - m.def("floordiv", &ir::dispatch::floordiv, ret::reference); - m.def("fdiv", &ir::dispatch::fdiv, ret::reference); - m.def("mod", &ir::dispatch::mod, ret::reference); - m.def("and_", &ir::dispatch::and_, ret::reference); - m.def("or_", &ir::dispatch::or_, ret::reference); - m.def("xor_", &ir::dispatch::xor_, ret::reference); - m.def("lshr", &ir::dispatch::lshr, ret::reference); - m.def("shl", &ir::dispatch::shl, ret::reference); - // unary - m.def("plus", &ir::dispatch::plus, ret::reference); - m.def("minus", &ir::dispatch::minus, ret::reference); - m.def("invert", &ir::dispatch::invert, ret::reference); - // comparison - m.def("greater_than", &ir::dispatch::greater_than, ret::reference); - m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference); - m.def("less_than", &ir::dispatch::less_than, ret::reference); - m.def("less_equal", &ir::dispatch::less_equal, ret::reference); - m.def("equal", &ir::dispatch::equal, ret::reference); - m.def("not_equal", &ir::dispatch::not_equal, ret::reference); - // block creation - m.def("arange", &ir::dispatch::arange, ret::reference); - m.def("zeros", &ir::dispatch::zeros, ret::reference); - // type manipuatation - m.def("cat", &ir::dispatch::cat, ret::reference); - m.def("reshape", &ir::dispatch::reshape, ret::reference); - typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); - typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); - m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference); - m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference); - m.def("bitcast", &ir::dispatch::bitcast, ret::reference); - m.def("cast", &ir::dispatch::cast, ret::reference); - // memory - m.def("load", &ir::dispatch::load, ret::reference); - m.def("store", &ir::dispatch::store, ret::reference); - m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); - m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); - m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference); - m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference); - m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference); - m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference); - m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference); - m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference); - // linear algebra - m.def("dot", &ir::dispatch::dot, ret::reference); - // indexing - m.def("where", &ir::dispatch::where, ret::reference); - // reduction - m.def("min", &ir::dispatch::min, ret::reference); - m.def("max", &ir::dispatch::max, ret::reference); - m.def("sum", &ir::dispatch::sum, ret::reference); - m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); - // math - m.def("umulhi", &ir::dispatch::umulhi, ret::reference); - m.def("exp", &ir::dispatch::exp, ret::reference); - m.def("log", &ir::dispatch::log, ret::reference); - m.def("cos", &ir::dispatch::cos, ret::reference); - m.def("sin", &ir::dispatch::sin, ret::reference); - m.def("sqrt", &ir::dispatch::sqrt, ret::reference); - // internal (debugging only) - m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); - m.def("max_contiguous", &ir::dispatch::max_contiguous, 
ret::reference); - m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); -} /*****************************************************************************/ /* Python bindings for triton::ir */ @@ -628,16 +551,86 @@ void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; + py::enum_(m, "CACHE_MODIFIER") + .value("NONE", ir::load_inst::NONE) + .value("CA", ir::load_inst::CA) + .value("CG", ir::load_inst::CG) + .export_values(); + + py::enum_(m, "EVICTION_POLICY") + .value("NORMAL", ir::load_inst::NORMAL) + .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST) + .value("EVICT_LAST", ir::load_inst::EVICT_LAST) + .export_values(); + + py::enum_(m, "REDUCE_OP") + .value("ADD", ir::reduce_inst::ADD) + .value("FADD", ir::reduce_inst::FADD) + .value("MIN", ir::reduce_inst::MIN) + .value("MAX", ir::reduce_inst::MAX) + .value("FMIN", ir::reduce_inst::FMIN) + .value("FMAX", ir::reduce_inst::FMAX) + .value("XOR", ir::reduce_inst::XOR); + + py::enum_(m, "ATOMIC_OP") + .value("ADD", ir::atomic_rmw_op_t::Add) + .value("FADD", ir::atomic_rmw_op_t::FAdd) + .value("AND", ir::atomic_rmw_op_t::And) + .value("OR", ir::atomic_rmw_op_t::Or) + .value("XOR", ir::atomic_rmw_op_t::Xor) + .value("XCHG", ir::atomic_rmw_op_t::Xchg) + .value("MAX", ir::atomic_rmw_op_t::Max) + .value("MIN", ir::atomic_rmw_op_t::Min) + .value("UMIN", ir::atomic_rmw_op_t::UMin) + .value("UMAX", ir::atomic_rmw_op_t::UMax); + py::class_(m, "context") .def(py::init<>()); - auto value = py::class_(m, "value"); - value.def_property("name", &ir::value::get_name, &ir::value::set_name); - value.def_property_readonly("type", &ir::value::get_type); + py::class_(m, "value") + .def("multiple_of", [](ir::value *self, int val) { + if (auto *instr = dynamic_cast(self)) { + instr->set_metadata(ir::metadata::multiple_of, val); + } else + throw std::runtime_error("multiple_of"); + }) + .def("max_contiguous", [](ir::value *self, int val) { + if (auto *instr = dynamic_cast(self)) { + instr->set_metadata(ir::metadata::max_contiguous, val); + } else + throw std::runtime_error("max_contiguous"); + }) + .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) { + if (auto *instr = dynamic_cast(self)) + instr->set_fdiv_ieee_rounding(val); + else + throw std::runtime_error("set_fdiv_ieee_rounding"); + }) + .def("is_phi", [](ir::value *self) { + if (auto *pn = dynamic_cast(self)) + return true; + return false; + }) + .def("ops", [](ir::value *self) { + if (auto *instr = dynamic_cast(self)) { + return instr->ops(); + } + throw std::runtime_error("cannot use ops()"); + }) + .def("replace_all_uses_with", &ir::value::replace_all_uses_with) + .def("erase_from_parent", [](ir::value *self) { + if (auto *instr = dynamic_cast(self)) + return instr->erase_from_parent(); + throw std::runtime_error("cannot use erase_from_parent"); + }) + .def_property("name", &ir::value::get_name, &ir::value::set_name) + .def_property_readonly("type", &ir::value::get_type); py::class_(m, "user"); - py::class_(m, "constant"); + py::class_(m, "constant") + .def("get_null_value", &ir::constant::get_null_value, ret::reference) + .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference); py::class_(m, "undef") .def("get", &ir::undef_value::get, ret::reference); @@ -648,16 +641,17 @@ void init_triton_ir(py::module &&m) { .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); py::class_(m, "constant_float") - .def_property_readonly("value", &ir::constant_fp::get_value); + 
.def_property_readonly("value", &ir::constant_fp::get_value) + .def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference); - py::class_(m, "instruction"); - py::class_(m, "phi_node"); + py::class_(m, "instruction") + .def("get_parent", [](ir::instruction *self) { + return self->get_parent(); + }, ret::reference); + py::class_(m, "phi_node") + .def("add_incoming", &ir::phi_node::add_incoming); py::class_(m, "type") - .def("is_ptr", &ir::type::is_pointer_ty) - .def("is_int", static_cast(&ir::type::is_integer_ty)) - .def("is_floating", &ir::type::is_floating_point_ty) - .def("is_block", &ir::type::is_block_ty) .def("make_ptr", &ir::pointer_type::get, ret::reference) .def("make_function", &ir::function_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference) @@ -672,34 +666,38 @@ void init_triton_ir(py::module &&m) { .def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference) - .def("get_uint8", &ir::type::get_uint8_ty, ret::reference) - .def("get_uint16", &ir::type::get_uint16_ty, ret::reference) - .def("get_uint32", &ir::type::get_uint32_ty, ret::reference) - .def("get_uint64", &ir::type::get_uint64_ty, ret::reference) + .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference) + .def("get_block_shapes", &ir::type::get_block_shapes) + + .def("is_ptr", &ir::type::is_pointer_ty) + .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("is_floating", &ir::type::is_floating_point_ty) + .def("is_block", &ir::type::is_block_ty) .def("is_void", &ir::type::is_void_ty) + .def("is_bool", &ir::type::is_bool_ty) .def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp16", &ir::type::is_fp16_ty) .def("is_bf16", &ir::type::is_bf16_ty) .def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) - .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) - .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) - .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) - .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) + .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty) .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) - .def_property_readonly("context", &ir::type::get_context, ret::reference); + .def_property_readonly("context", 
&ir::type::get_context, ret::reference) + .def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth) + .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits); py::class_(m, "pointer_type") - .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); + .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) + .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference); py::class_(m, "function_type"); py::class_(m, "integer_type"); @@ -709,16 +707,15 @@ void init_triton_ir(py::module &&m) { py::class_(m, "module") .def(py::init()) - .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) - .def("seal_block", &ir::module::seal_block) - .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) - .def("set_type", &ir::module::set_type) - .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) - .def("get_values", &ir::module::get_values, ret::reference) - .def("set_values", &ir::module::set_values) - .def("get_types", &ir::module::get_types, ret::reference) - .def("set_types", &ir::module::set_types) - .def_property_readonly("builder", &ir::module::get_builder, ret::reference); + .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { + const auto metadatas = self->get_metadatas(); + auto it = metadatas.find(name); + if (it != metadatas.end()) + if (auto *instr = dynamic_cast(value)) { + instr->set_metadata(it->second.first, it->second.second); + } + }) + .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); using eattr = ir::attribute_kind_t; py::enum_(m, "attribute_kind") @@ -742,6 +739,13 @@ void init_triton_ir(py::module &&m) { py::class_(m, "basic_block") .def("create", &ir::basic_block::create, ret::reference) + .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) + .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* { + ir::basic_block::iterator it = self->get_first_non_phi(); + if (it == self->end()) + return nullptr; + return *it; + }, ret::reference) .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); py::class_(m, "builder", py::dynamic_attr()) @@ -752,17 +756,162 @@ void init_triton_ir(py::module &&m) { .def("br", &ir::builder::create_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) + // insertion block/point, insert points are represented as (*bb, *instr) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) - // constants + .def("get_insert_point", [](ir::builder *self) { + ir::basic_block *bb = self->get_insert_block(); + ir::basic_block::iterator it = self->get_insert_point(); + ir::instruction *instr = it == bb->end() ? 
nullptr : *it; + return std::make_pair(bb, instr); + }, ret::reference) + .def("set_insert_point", [](ir::builder *self, std::pair pt) { + ir::basic_block *bb = pt.first; + ir::instruction *instr = pt.second; + if (instr) { + if (bb != instr->get_parent()) + throw std::runtime_error("invalid insertion point, instr not in bb"); + self->set_insert_point(instr); + } else { + assert(bb); + self->set_insert_point(bb); + } + }) + // Constants .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", &ir::builder::get_int32, ret::reference) - .def("get_int64", &ir::builder::get_int64, ret::reference) - .def("get_uint32", &ir::builder::get_uint32, ret::reference) - .def("get_uint64", &ir::builder::get_uint64, ret::reference) + .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) + .def("get_uint32", &ir::builder::get_int32, ret::reference) + .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) + .def("get_uint64", &ir::builder::get_int64, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference) - .def("get_range", &ir::builder::get_range, ret::reference); + .def("get_range", &ir::builder::get_range, ret::reference) + // Types + .def("get_void_ty", &ir::builder::get_void_ty, ret::reference) + .def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference) + .def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference) + .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference) + .def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference) + .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference) + .def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference) + .def("get_half_ty", &ir::builder::get_half_ty, ret::reference) + .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference) + .def("get_float_ty", &ir::builder::get_float_ty, ret::reference) + .def("get_double_ty", &ir::builder::get_double_ty, ret::reference) + // terminator instructions + .def("create_br", &ir::builder::create_br, ret::reference) + .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) + .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) + // Cast instructions + .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) + .def("create_cast", &ir::builder::create_cast, ret::reference) + .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) + .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) + .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) + .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) + .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) + .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference) + .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) + .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) + .def("create_downcast", &ir::builder::create_downcast, ret::reference) + // phi + .def("create_phi", &ir::builder::create_phi, ret::reference) + // Binary instructions + .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) + .def("create_fmul", &ir::builder::create_fmul, ret::reference) + .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) + .def("create_frem", &ir::builder::create_frem, ret::reference) + .def("create_fadd", 
&ir::builder::create_fadd, ret::reference) + .def("create_fsub", &ir::builder::create_fsub, ret::reference) + .def("create_mul", &ir::builder::create_mul, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) + .def("create_udiv", &ir::builder::create_udiv, ret::reference) + .def("create_srem", &ir::builder::create_srem, ret::reference) + .def("create_urem", &ir::builder::create_urem, ret::reference) + .def("create_add", &ir::builder::create_add, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_sub", &ir::builder::create_sub, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_shl", &ir::builder::create_shl, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_lshr", &ir::builder::create_lshr, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_ashr", &ir::builder::create_ashr, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + // GEP + .def("create_gep", &ir::builder::create_gep, ret::reference) + // Comparison (int) + .def("create_icmp", &ir::builder::create_icmp, ret::reference) + .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference) + .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference) + .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference) + .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference) + .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference) + .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference) + .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference) + .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference) + .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference) + .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference) + // Comparison (float) + .def("create_fcmp", &ir::builder::create_fcmp, ret::reference) + .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference) + .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) + .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) + .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) + .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) + .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) + .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) + .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) + .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) + .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) + .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) + .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) + // Logical + .def("create_and", &ir::builder::create_and, ret::reference) + .def("create_xor", &ir::builder::create_xor, ret::reference) + .def("create_or", &ir::builder::create_or, ret::reference) + // Input/Output + .def("create_load", &ir::builder::create_load, ret::reference) + .def("create_store", &ir::builder::create_store, ret::reference) + .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) + .def("create_masked_store", 
&ir::builder::create_masked_store, ret::reference) + // Block instruction + .def("create_splat", &ir::builder::create_splat, ret::reference) + .def("create_reshape", &ir::builder::create_reshape, ret::reference) + .def("create_cat", &ir::builder::create_cat, ret::reference) + .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) + // atomic + .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) + .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) + + // Built-in instruction + .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) + .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) + .def("create_exp", &ir::builder::create_exp, ret::reference) + .def("create_cos", &ir::builder::create_cos, ret::reference) + .def("create_sin", &ir::builder::create_sin, ret::reference) + .def("create_log", &ir::builder::create_log, ret::reference) + .def("create_dot", &ir::builder::create_dot, ret::reference) + .def("create_trans", &ir::builder::create_trans, ret::reference) + .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) + .def("create_reduce", &ir::builder::create_reduce, ret::reference) + .def("create_select", &ir::builder::create_select, ret::reference) + // Intrinsics + // These have no place in the IR, and hopefully they can be removed at some point + .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) + .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference) + .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference) + .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference) + .def("create_barrier", &ir::builder::create_barrier, ret::reference) + .def("create_async_wait", &ir::builder::create_async_wait, ret::reference) + .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference); } void init_triton(py::module &m) { @@ -770,5 +919,4 @@ void init_triton(py::module &m) { init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_ir(std::move(subm.def_submodule("ir"))); - init_triton_frontend(std::move(subm.def_submodule("frontend"))); } diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 1df3a0b49..f30b203bb 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -37,7 +37,7 @@ matmul_data = { (256, 256, 256): {'float16': 0.027}, (512, 512, 512): {'float16': 0.158}, (1024, 1024, 1024): {'float16': 0.466}, - (2048, 2048, 2048): {'float16': 0.680}, + (2048, 2048, 2048): {'float16': 0.695}, (4096, 4096, 4096): {'float16': 0.831}, (8192, 8192, 8192): {'float16': 0.849}, # tall-skinny diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a49b47585..3561f7af4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,5 +1,4 @@ # flake8: noqa: F821,F841 -import copy import itertools import re from typing import Optional, Union @@ -585,7 +584,6 @@ def test_f8_f16_roundtrip(): f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - print(f16.dtype, f8_output.dtype) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -993,27 +991,6 @@ def 
test_noop(device='cuda'): kernel[(1, )](x) -@pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), - (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), - (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') -]) -def test_value_specialization(value: int, value_type: str, device='cuda') -> None: - - @triton.jit - def kernel(VALUE, X): - pass - - x = torch.tensor([3.14159], device='cuda') - pgm = kernel[(1, )](value, x) - - # Parse out the type of the 'VALUE' parameter from the Triton IR. - triton_ir = pgm.asm['ttir'] - ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) - ir_value_type = None if ir_value_match is None else ir_value_match.group(1) - assert ir_value_type == value_type - - @pytest.mark.parametrize( "value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 8ac01bcc8..d866d6983 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,4 +1,5 @@ import os +import re import shutil import pytest @@ -102,3 +103,30 @@ def test_specialize(mode): for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target + + +@pytest.mark.parametrize("value, value_type", [ + (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), + (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), + (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64') +]) +def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + cache_str = None + + def get_cache_str(*args, **kwargs): + nonlocal cache_str + cache_str = kwargs['key'].split('-') + triton.code_gen.JITFunction.cache_hook = get_cache_str + reset_tmp_dir() + x = torch.tensor([3.14159], device='cuda') + kernel[(1, )](value, x) + triton.code_gen.JITFunction.cache_hook = None + + cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) + spec_type = None if cache_str_match is None else cache_str_match.group(1) + assert spec_type == value_type diff --git a/python/triton/__init__.py b/python/triton/__init__.py index f9982939c..37ba46efc 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,7 +6,8 @@ __version__ = '2.0.0' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret +from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ + JITFunction, Config, Autotuner, reinterpret from . import language from . import code_gen from . 
import testing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 09254c967..a253e2c4c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast import builtins import functools @@ -11,7 +13,7 @@ import tempfile import textwrap import time import warnings -from typing import Dict +from typing import Dict, Optional, Set, Tuple, Union import torch from filelock import FileLock @@ -22,48 +24,13 @@ from .tools.disasm import extract class CodeGenerator(ast.NodeVisitor): - def get_value(self, name): - # search node.id in local scope - ret = None - if name in self.lscope: - ret = self.lscope[name] - # search node.id in global scope - elif name in self.gscope: - ret = self.gscope[name] - # search node.id in builtins - elif name in self.builtins: - ret = self.builtins[name] - else: - raise ValueError(f'{name} is not defined') - if isinstance(ret, triton.language.block): - handle = self.module.get_value(name) - return triton.language.block(handle) - return ret - - def set_value(self, name, value): - if isinstance(value, _triton.ir.value): - value = triton.language.block(value) - if isinstance(value, triton.language.block): - self.module.set_value(name, value.handle) - self.module.set_type(name, value.handle.type) - self.lscope[name] = value - - def is_triton_object(self, value): - return isinstance(value, triton.language.block) - - def visit_compound_statement(self, stmts): - for stmt in stmts: - self.last_ret = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) - def __init__(self, context, prototype, gscope, attributes, constants, kwargs): self.builder = _triton.ir.builder(context) self.module = _triton.ir.module('', self.builder) self.prototype = prototype self.gscope = gscope self.lscope = dict() + self.is_arg_lscope = dict() # name => is_arg: {str: bool} self.attributes = attributes self.constants = constants self.kwargs = kwargs @@ -77,6 +44,146 @@ class CodeGenerator(ast.NodeVisitor): 'isinstance': isinstance, 'getattr': getattr, } + # SSA-construction + # [name, bb] => triton.language.tensor + self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {} + # bb => {name => phi} + self.incomplete_phis = {} + self.sealed_blocks: Set[_triton.ir.basic_block] = set() + + def get_value(self, name): + ''' This function: + 1. make sure `name` is defined + 2. if `name` is triton.language.tensor, get stored tensor by calling + `self._get_tensor()` + ''' + # search node.id in local scope + ret = None + if name in self.lscope: + ret = self.lscope[name] + # search node.id in global scope + elif name in self.gscope: + ret = self.gscope[name] + # search node.id in builtins + elif name in self.builtins: + ret = self.builtins[name] + else: + raise ValueError(f'{name} is not defined') + if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]: + return self._get_tensor(name) + return ret + + def set_value(self, name: str, + value: Union[triton.language.tensor, triton.language.constexpr], + is_arg: bool = False) -> None: + ''' This function: + called by visit_Assign() & visit_FuncDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. 
store the tensor in self.lvalues + ''' + self.lscope[name] = value + # if this value is an argument, we don't need to create phis for it + self.is_arg_lscope[name] = is_arg + if isinstance(value, triton.language.tensor) and not is_arg: + self._set_value(name, self.builder.get_insert_block(), value) + + # + # SSA-construction + # + def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor: + if not bb: + bb = self.builder.get_insert_block() + # local value numbering + if (name, bb) in self.lvalues: + return self.lvalues[(name, bb)] + # global value numbering + saved_insert_point = self.builder.get_insert_point() + result = self._get_tensor_recursive(name, bb) + self.builder.set_insert_point(saved_insert_point) + return result + + def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: + preds = bb.get_predecessors() + type = self.lscope[name].type + # some preds haven't been filled, create a phi as a proxy of the value + if bb not in self.sealed_blocks: + result = self._make_phi(type, len(preds), bb) + if bb in self.incomplete_phis: + self.incomplete_phis[bb][name] = result + else: + self.incomplete_phis[bb] = {name: result} + elif len(preds) == 1: + # one predecessor: no phi needed, try to get value from pred + result = self._get_tensor(name, preds[0]) + else: # multiple preds + assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)' + phi = self._make_phi(type, len(preds), bb) + self._set_value(name, bb, phi) + result = self._add_phi_operands(name, phi) + self._set_value(name, bb, result) + return result + + # returns a new phi tensor, which encapsulates an ir.phi_node + def _make_phi(self, + type: triton.language.dtype, + num_values: int, + bb: _triton.ir.basic_block) -> triton.language.tensor: + instr = bb.get_first_non_phi() + self.builder.set_insert_point((bb, instr)) + ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values) + if instr: + self.builder.set_insert_block(bb) + return triton.language.tensor(ir_phi, type) + + # complete a phi node. (TODO: rename this as _complete_phis?) + # Note: since we try to remove trivial phis, the returned tensor might not be a phi + def _add_phi_operands(self, name: str, + phi: triton.language.tensor) -> triton.language.tensor: + bb = phi.handle.get_parent() + for pred in bb.get_predecessors(): + v = self._get_tensor(name, pred) + phi.handle.add_incoming(v.handle, pred) + phi = self._try_remove_trivial_phi(phi) + return phi + + def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None: + self.lvalues[(name, bb)] = value + # TODO: why do we need this? 
+ self.module.set_instr_metadata(name, value.handle) + + def _seal_block(self, bb: _triton.ir.basic_block): + # complete all incomplete phis + if bb in self.incomplete_phis: + for name, phi in self.incomplete_phis[bb].items(): + result = self._add_phi_operands(name, phi) + # it's possible that this phi is trivial + if self._get_tensor(name, bb).handle == phi.handle: + self._set_value(name, bb, result) + del self.incomplete_phis[bb] + self.sealed_blocks.add(bb) + + def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor: + unique_handles = {op for op in phi.handle.ops() if op != phi.handle} + if len(unique_handles) != 1: # non-trivial phi + return phi + v = unique_handles.pop() + phi.handle.replace_all_uses_with(v) + phi.handle.erase_from_parent() + # TODO: remove trivial phis recursively + return triton.language.tensor(v, phi.type) + + def is_triton_tensor(self, value): + return isinstance(value, triton.language.tensor) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + for stmt in stmts: + self.last_ret = self.visit(stmt) + if isinstance(stmt, ast.Return): + break + return stmts and isinstance(stmt, ast.Return) def visit_Module(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -113,7 +220,7 @@ class CodeGenerator(ast.NodeVisitor): if inline: pass else: - fn = self.module.get_or_insert_function(node.name, self.prototype) + fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder)) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -130,17 +237,17 @@ class CodeGenerator(ast.NodeVisitor): attr = _triton.ir.attribute(attr, self.attributes[i]) fn.add_attr(idx + 1, attr) fn.args[idx].name = arg_name - arg_values.append(fn.args[idx]) + arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) idx += 1 for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) + self.set_value(arg_name, arg_value, is_arg=True) if inline: self.visit_compound_statement(node.body) return self.last_ret else: entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) - self.module.seal_block(entry) + self._seal_block(entry) self.builder.set_insert_block(entry) # visit function body self.visit_compound_statement(node.body) @@ -187,11 +294,12 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): + # TODO: can we store constexpr here to support constant folding? 
# by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value - if not isinstance(value, triton.language.block): - value = triton.language.core._to_ir(value, self.builder) + if not isinstance(value, triton.language.tensor): + value = triton.language.core._to_tensor(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -220,9 +328,9 @@ class CodeGenerator(ast.NodeVisitor): def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) - if isinstance(lhs, triton.language.core.constexpr): + if isinstance(lhs, triton.language.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.core.constexpr): + if isinstance(rhs, triton.language.constexpr): rhs = rhs.value fn = { ast.Add: '__add__', @@ -238,9 +346,9 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - if self.is_triton_object(lhs): + if self.is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_object(rhs): + elif self.is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -248,15 +356,15 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if isinstance(cond, triton.language.block): + if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self.module.seal_block(then_bb) + self._seal_block(then_bb) if else_bb: - self.module.seal_block(else_bb) + self._seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -271,7 +379,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
if not is_terminator: self.builder.br(endif_bb) - self.module.seal_block(endif_bb) + self._seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -296,9 +404,9 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.core.constexpr): + if isinstance(lhs, triton.language.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.core.constexpr): + if isinstance(rhs, triton.language.constexpr): rhs = rhs.value if type(node.ops[0]) == ast.Is: return triton.language.constexpr(lhs is rhs) @@ -312,9 +420,9 @@ class CodeGenerator(ast.NodeVisitor): ast.Gt: '__gt__', ast.GtE: '__ge__', }[type(node.ops[0])] - if self.is_triton_object(lhs): + if self.is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_object(rhs): + elif self.is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -325,21 +433,21 @@ class CodeGenerator(ast.NodeVisitor): if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.core.constexpr): + if isinstance(op, triton.language.constexpr): op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Invert: '__invert__', }[type(node.op)] - if self.is_triton_object(op): + if self.is_triton_tensor(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() def visit_While(self, node): current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) def continue_fn(): cond = self.visit(node.test) @@ -350,9 +458,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self._seal_block(stop_bb) + self._seal_block(loop_bb) + self._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -362,7 +470,7 @@ class CodeGenerator(ast.NodeVisitor): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) - if self.is_triton_object(lhs): + if self.is_triton_tensor(lhs): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] @@ -405,8 +513,8 @@ class CodeGenerator(ast.NodeVisitor): step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) def continue_fn(): self.visit(step_node) @@ -421,9 +529,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where 
body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self._seal_block(stop_bb) + self._seal_block(loop_bb) + self._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -451,7 +559,7 @@ class CodeGenerator(ast.NodeVisitor): args = [self.visit(arg) for arg in node.args] if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) - if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ + if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): @@ -581,7 +689,7 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] - if isinstance(obj, triton.language.core.constexpr): + if isinstance(obj, triton.language.constexpr): obj = obj.value if isinstance(obj, int): if -2**31 <= obj < 2**31: @@ -613,34 +721,34 @@ class Kernel: return 'scalar', name @staticmethod - def _to_triton_ir(context, obj): + def _to_triton_ir(obj): which, name = obj type_map = { - 'I': _triton.ir.type.get_int32, - 'L': _triton.ir.type.get_int64, - 'f': _triton.ir.type.get_fp32, - 'B': _triton.ir.type.get_int1, - 'f8': _triton.ir.type.get_fp8, - 'f16': _triton.ir.type.get_fp16, - 'bf16': _triton.ir.type.get_bf16, - 'f32': _triton.ir.type.get_fp32, - 'f64': _triton.ir.type.get_fp64, - 'i1': _triton.ir.type.get_int1, - 'i8': _triton.ir.type.get_int8, - 'i16': _triton.ir.type.get_int16, - 'i32': _triton.ir.type.get_int32, - 'i64': _triton.ir.type.get_int64, - 'u8': _triton.ir.type.get_uint8, - 'u16': _triton.ir.type.get_uint16, - 'u32': _triton.ir.type.get_uint32, - 'u64': _triton.ir.type.get_uint64, + 'I': triton.language.int32, + 'L': triton.language.int64, + 'f': triton.language.float32, + 'B': triton.language.int1, + 'f8': triton.language.float8, + 'f16': triton.language.float16, + 'bf16': triton.language.bfloat16, + 'f32': triton.language.float32, + 'f64': triton.language.float64, + 'i1': triton.language.int1, + 'i8': triton.language.int8, + 'i16': triton.language.int16, + 'i32': triton.language.int32, + 'i64': triton.language.int64, + 'u8': triton.language.uint8, + 'u16': triton.language.uint16, + 'u32': triton.language.uint32, + 'u64': triton.language.uint64, } # convert torch.Tensor to Triton IR pointers if which == 'ptr': - elt_ty = type_map[name](context) - return _triton.ir.type.make_ptr(elt_ty, 1) + elt_ty = type_map[name] + return triton.language.pointer_type(elt_ty, 1) # default path returns triton.ir.type directly - return type_map[name](context) + return type_map[name] @staticmethod def pow2_divisor(N): @@ -920,25 +1028,31 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree + # Called by CodeGenerator.visit_Call() def __call__(self, *args, generator: CodeGenerator, **kwargs): try: from inspect import getcallargs arg_values = getcallargs(self.fn, *args, **kwargs) arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.block) + arg_values = [arg if isinstance(arg, triton.language.tensor) else triton.language.constexpr(arg) for arg in arg_values] + # Record values in the caller (parent scope) gscope = generator.gscope.copy() lscope = generator.lscope.copy() - values = generator.module.get_values().copy() - types = generator.module.get_types().copy() + + # TODO: clear values 
other than args + lvalues = generator.lvalues.copy() + # types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope - generator.module.set_values(values) - generator.module.set_types(types) + + generator.lvalues = lvalues + # generator.module.set_types(types) + return ret except Exception as e: node = generator.last_node @@ -1023,9 +1137,9 @@ class JITFunction: # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] - ret_type = _triton.ir.type.get_void(context) - prototype = _triton.ir.type.make_function(ret_type, arg_types) + arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] + ret_type = triton.language.void + prototype = triton.language.function_type(ret_type, arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df25e59fb..81b9fe790 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,63 +1,36 @@ +from __future__ import annotations + +from enum import Enum from functools import wraps +from typing import List import triton -from triton._C.libtriton.triton import frontend, ir +from . import semantic +from triton._C.libtriton.triton import ir -# convert block/dtype to ir values -def _to_ir(x, builder): +def _to_tensor(x, builder): if isinstance(x, bool): - return builder.get_int1(x) + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values elif isinstance(x, int): if -2**31 <= x < 2**31: - return builder.get_int32(x) + return tensor(builder.get_int32(x), int32) elif 2**31 <= x < 2**32: - return builder.get_uint32(x) + return tensor(builder.get_uint32(x), uint32) elif -2**63 <= x < 2**63: - return builder.get_int64(x) + return tensor(builder.get_int64(x), int64) elif 2**63 <= x < 2**64: - return builder.get_uint64(x) + return tensor(builder.get_uint64(x), uint64) else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - return builder.get_float32(x) + return tensor(builder.get_float32(x), float32) elif isinstance(x, constexpr): - return _to_ir(x.value, builder) - elif isinstance(x, block): - return x.handle - elif isinstance(x, dtype): - return x.handle(builder) - return x - - -def _patch(fn): - def _from_ir(x): - if isinstance(x, ir.value): - if x.type.is_void(): - return None - return block(x) + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): return x - - def wrapper(*args, **kwargs): - builder = args[-1] - assert isinstance(builder, ir.builder) - args = [_to_ir(x, builder) for x in args] - # for i, arg in enumerate(args): - # if arg is None: - # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") - kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} - ret = fn(*args, **kwargs) - if isinstance(ret, tuple): - return map(_from_ir, ret) - return _from_ir(ret) - - return wrapper - - -for name in dir(frontend): - fn = getattr(frontend, name) - if callable(fn): - setattr(frontend, name, _patch(fn)) + assert False, f'cannot convert {x} to tensor' def builtin(fn): @@ -72,20 +45,147 @@ def builtin(fn): class dtype: - def __init__(self, init): - 
self.init = init + SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return self.name == 'fp8' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def is_void(self): + raise RuntimeError("Not implemented") + + def is_block(self): + return False + + def is_ptr(self): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name,)) @property - def name(self) -> str: - # The init functions are named something like 'get_int8'. Strip the prefix. 
- nom = self.init.__name__ - prefix = 'get_' - assert nom.startswith(prefix) - return nom[len(prefix):] + def scalar(self): + return self - def handle(self, builder): - ctx = builder.context - return self.init(ctx) + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8' or self.name == 'uint8': + return builder.get_int8_ty() + elif self.name == 'int16' or self.name == 'uint16': + return builder.get_int16_ty() + elif self.name == 'int32' or self.name == 'uint32': + return builder.get_int32_ty() + elif self.name == 'int64' or self.name == 'uint64': + return builder.get_int64_ty() + elif self.name == 'fp8': + return builder.get_fp8_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') def __str__(self): return self.name @@ -99,36 +199,112 @@ class dtype: return f'triton.language.{self.name}' -class pointer_dtype: - def __init__(self, element_ty): +class pointer_type(dtype): + def __init__(self, element_ty: dtype, address_space: int = 1): if not isinstance(element_ty, dtype): raise TypeError(f'element_ty is a {type(element_ty).__name__}.') self.element_ty = element_ty + self.address_space = address_space - def handle(self, builder): - return ir.type.make_ptr(self.element_ty.handle(builder), 1) + self.name = self.__str__() + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return ir.type.make_ptr(self.element_ty.to_ir(builder), 1) def __str__(self): return f'pointer<{self.element_ty}>' + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class block_type(dtype): + def __init__(self, element_ty: dtype, shape: List[int]): + self.element_ty = element_ty + # FIXME: + # block_type's shape is a list of int + # while tensor's shape is a list of constexpr + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + + self.name = self.__str__() + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return ir.type.make_block(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return f'<{self.shape}, {self.element_ty}>' + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None: + self.ret_type = ret_type + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_type}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in 
self.param_types] + return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types) + # scalar types -int1 = dtype(ir.type.get_int1) -int8 = dtype(ir.type.get_int8) -int16 = dtype(ir.type.get_int16) -int32 = dtype(ir.type.get_int32) -int64 = dtype(ir.type.get_int64) -uint8 = dtype(ir.type.get_uint8) -uint16 = dtype(ir.type.get_uint16) -uint32 = dtype(ir.type.get_uint32) -uint64 = dtype(ir.type.get_uint64) -float8 = dtype(ir.type.get_fp8) -float16 = dtype(ir.type.get_fp16) -bfloat16 = dtype(ir.type.get_bf16) -float32 = dtype(ir.type.get_fp32) -float64 = dtype(ir.type.get_fp64) +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8 = dtype('fp8') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') # pointer types -pi32_t = pointer_dtype(int32) +pi32_t = pointer_type(int32) # ----------------------- # constexpr @@ -149,7 +325,6 @@ class constexpr: def __repr__(self) -> str: return f"constexpr[{self.value}]" - # def __add__(self, other): return self.value + other.value @@ -219,31 +394,33 @@ class constexpr: return self.value(*args, **kwds) -class block: +class tensor: + # infer dtype from ir type @staticmethod - def _init_dtype(ir_type): + def _to_dtype(ir_type): + # block type + if ir_type.is_block(): + scalar_ty = tensor._to_dtype(ir_type.scalar) + return block_type(scalar_ty, ir_type.get_block_shapes()) + # pointer type + if ir_type.is_ptr(): + element_ty = tensor._to_dtype(ir_type.element) + return pointer_type(element_ty) # primitive type + if ir_type.is_void(): return void if ir_type.is_int1(): return int1 if ir_type.is_int8(): return int8 if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 - if ir_type.is_uint8(): return uint8 - if ir_type.is_uint16(): return uint16 - if ir_type.is_uint32(): return uint32 - if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16 if ir_type.is_fp32(): return float32 if ir_type.is_fp64(): return float64 - # pointer type - if ir_type.is_ptr(): - element_ty = block._init_dtype(ir_type.element) - return pointer_dtype(element_ty) - raise ValueError(f"Unsupported type {ir_type}") + raise ValueError(f"Unsupported type {ir_type.repr()}") - def __init__(self, handle): + def __init__(self, handle, type: dtype): # IR handle self.handle = handle # Block shape @@ -254,9 +431,9 @@ class block: for s in self.shape: self.numel *= s self.numel = constexpr(self.numel) - # Data-type wrapper - self.dtype = block._init_dtype(self.handle.type.scalar) - # Shape is a constexpr + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: @@ -265,116 +442,139 @@ class block: @builtin def __add__(self, other, _builder=None): - return frontend.add(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - return frontend.sub(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) def __rsub__(self, other, 
_builder=None): - return frontend.sub(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) @builtin def __mul__(self, other, _builder=None): - return frontend.mul(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @builtin def __truediv__(self, other, _builder=None): - return frontend.truediv(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) def __rtruediv__(self, other, _builder=None): - return frontend.truediv(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) @builtin def __floordiv__(self, other, _builder=None): - return frontend.floordiv(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) @builtin def __mod__(self, other, _builder=None): - return frontend.mod(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) # unary operators @builtin def __neg__(self, _builder=None): - return frontend.minus(self, _builder) + return semantic.minus(self, _builder) @builtin def __invert__(self, _builder=None): - return frontend.invert(self, _builder) + return semantic.invert(self, _builder) # bitwise operators @builtin def __and__(self, other, _builder=None): - return frontend.and_(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) @builtin def __or__(self, other, _builder=None): - return frontend.or_(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) @builtin def __xor__(self, other, _builder=None): - return frontend.xor_(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) @builtin def __lshift__(self, other, _builder=None): - return frontend.shl(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) @builtin def __rshift__(self, other, _builder=None): - return frontend.lshr(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.lshr(self, other, _builder) # comparison operators # > @builtin def __gt__(self, other, _builder=None): - return frontend.greater_than(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) @builtin def __rgt__(self, other, _builder=None): - return frontend.greater_than(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) # >= @builtin def __ge__(self, other, _builder=None): - return frontend.greater_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + @builtin def __rge__(self, other, _builder=None): - return frontend.greater_equal(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) # < @builtin def __lt__(self, other, _builder=None): - return frontend.less_than(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) @builtin def __rlt__(self, other, _builder=None): - return frontend.less_than(other, self, _builder) + return semantic.less_than(other, self, _builder) # <= @builtin def __le__(self, other, 
_builder=None): - return frontend.less_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) @builtin def __rle__(self, other, _builder=None): - return frontend.less_equal(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) # == @builtin def __eq__(self, other, _builder=None): - return frontend.equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) @builtin def __ne__(self, other, _builder=None): - return frontend.not_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) @builtin def __getitem__(self, slices, _builder=None): @@ -389,20 +589,25 @@ class block: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) curr += 1 - ret = frontend.reshape(self, dst_shape, _builder) + ret = semantic.reshape(self, dst_shape, _builder) return ret @builtin def to(self, dtype, bitcast=False, _builder=None): - dtype = dtype.handle(_builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value if bitcast: - return frontend.bitcast(self, dtype, _builder) - return frontend.cast(self, dtype, _builder) + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder) # ----------------------- # SPMD Programming Model # ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v @builtin @@ -414,13 +619,14 @@ def program_id(axis, _builder=None): :type axis: int """ # if axis == -1: - # pid0 = frontend.program_id(0, _builder) - # pid1 = frontend.program_id(1, _builder) - # pid2 = frontend.program_id(2, _builder) - # npg0 = frontend.num_programs(0, _builder) - # npg1 = frontend.num_programs(0, _builder) + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) # return pid0 + pid1*npg0 + pid2*npg0*npg1 - return frontend.program_id(axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) @builtin @@ -431,7 +637,8 @@ def num_programs(axis, _builder=None): :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - return frontend.num_programs(axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) # ----------------------- @@ -449,13 +656,15 @@ def arange(start, end, _builder=None): :param stop: End of the interval. Must be a power of two >= start. :type stop: int """ - return frontend.arange(start, end, _builder) + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) @builtin def zeros(shape, dtype, _builder=None): """ - Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. 
:param shape: Shape of the new array, e.g., (8, 16) or (8, ) :type shape: tuple of ints @@ -468,7 +677,8 @@ def zeros(shape, dtype, _builder=None): if not isinstance(d.value, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") shape = [x.value for x in shape] - return frontend.zeros(shape, dtype, _builder) + dtype = _constexpr_to_value(dtype) + return semantic.zeros(shape, dtype, _builder) # ----------------------- @@ -481,25 +691,25 @@ def broadcast(input, other, _builder=None): """ Tries to broadcast the two given blocks to a common compatible shape. - :param input: The first input block. + :param input: The first input tensor. :type input: Block - :param other: The second input block. + :param other: The second input tensor. :type other: Block """ - return frontend.broadcast(input, other, _builder) + return semantic.broadcast_impl_value(input, other, _builder) @builtin def broadcast_to(input, shape, _builder=None): """ - Tries to broadcast the given block to a new :code:`shape`. + Tries to broadcast the given tensor to a new :code:`shape`. - :param input: The input block. + :param input: The input tensor. :type input: Block :param shape: The desired shape. :type shape: Tuple[int] """ - return frontend.broadcast_to(input, shape, _builder) + return semantic.broadcast_impl_shape(input, shape, _builder) @builtin @@ -507,27 +717,27 @@ def cat(input, other, _builder=None): """ Concatenate the given blocks - :param input: The first input block. + :param input: The first input tensor. :type input: - :param other: The second input block. + :param other: The second input tensor. :type other: """ - return frontend.cat(input, other, _builder) + return semantic.cat(input, other, _builder) @builtin def reshape(input, shape, _builder=None): """ - Tries to reshape the given block to a new shape. + Tries to reshape the given tensor to a new shape. - :param input: The input block. + :param input: The input tensor. :type input: :param shape: The desired shape. :type shape: Tuple[int] """ shape = [x.value for x in shape] - return frontend.reshape(input, shape, _builder) + return semantic.reshape(input, shape, _builder) # ----------------------- @@ -542,12 +752,13 @@ def dot(input, other, allow_tf32=True, _builder=None): The two blocks must be two dimensionals and have compatible inner dimensions. - :param input: The first block to be multiplied. - :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} - :param other: The second block to be multiplied. - :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ - return frontend.dot(input, other, allow_tf32, _builder) + allow_tf32 = _constexpr_to_value(allow_tf32) + return semantic.dot(input, other, allow_tf32, _builder) # ----------------------- @@ -558,7 +769,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ - Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. 
+ Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. @@ -573,24 +784,36 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) + # mask, other can be constexpr + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) @builtin def store(pointer, value, mask=None, _builder=None): """ - Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. + Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. :param pointer: The memory locations where the elements of :code:`value` are stored. :type pointer: Block of dtype=triton.PointerDType - :param value: The block of elements to be stored. + :param value: The tensor of elements to be stored. :type value: Block :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`. :type mask: Block of triton.int1, optional """ - return frontend.store(pointer, value, mask, _builder) + # value can be constexpr + value = _to_tensor(value, _builder) + if mask is not None: + mask = _to_tensor(mask, _builder) + return semantic.store(pointer, value, mask, _builder) # ----------------------- @@ -621,49 +844,58 @@ def _add_atomic_docstr(name): @builtin @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): - return frontend.atomic_cas(pointer, cmp, val, _builder) + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(cmp, _builder) + return semantic.atomic_cas(pointer, cmp, val, _builder) @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, _builder=None): - return frontend.atomic_xchg(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_xchg(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, _builder=None): - return frontend.atomic_add(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_add(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("max") def atomic_max(pointer, val, mask=None, _builder=None): - return frontend.atomic_max(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_max(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("min") def atomic_min(pointer, val, mask=None, _builder=None): - return frontend.atomic_min(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_min(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical and") def atomic_and(pointer, val, mask=None, _builder=None): - return frontend.atomic_and(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_and(pointer, val, mask, _builder) 
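(A minimal usage sketch for the atomic wrappers above, outside the diff. It assumes only the public tl.load / tl.atomic_add API shown in this patch; the kernel name, counts_ptr, and the BLOCK_SIZE=1024 choice are illustrative, not taken from the Triton sources. Note how the plain Python literal 1 passed to tl.atomic_add is first wrapped by _to_tensor before semantic.atomic_add is invoked.)

import torch
import triton
import triton.language as tl

@triton.jit
def _count_kernel(idx_ptr, counts_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # bin indices to increment; out-of-range lanes read bin 0 but are masked below
    idx = tl.load(idx_ptr + offsets, mask=mask, other=0)
    # the literal 1 is promoted to a tensor by _to_tensor, then semantic.atomic_add
    # broadcasts the mask and emits the atomic RMW
    tl.atomic_add(counts_ptr + idx, 1, mask=mask)

idx = torch.randint(0, 16, (4096,), device='cuda', dtype=torch.int32)
counts = torch.zeros(16, device='cuda', dtype=torch.int32)
n = idx.numel()
_count_kernel[(triton.cdiv(n, 1024),)](idx, counts, n, BLOCK_SIZE=1024)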
@builtin @_add_atomic_docstr("logical or") def atomic_or(pointer, val, mask=None, _builder=None): - return frontend.atomic_or(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_or(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical xor") def atomic_xor(pointer, val, mask=None, _builder=None): - return frontend.atomic_xor(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_xor(pointer, val, mask, _builder) # ----------------------- @@ -674,7 +906,7 @@ def atomic_xor(pointer, val, mask=None, _builder=None): @builtin def where(condition, x, y, _builder=None): """ - Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. @@ -688,7 +920,10 @@ def where(condition, x, y, _builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. """ - return frontend.where(condition, x, y, _builder) + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) # ----------------------- @@ -697,12 +932,15 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): - return frontend.umulhi(x, y, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.umulhi(x, y, _builder) @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): - return frontend.fdiv(x, y, ieee_rounding, _builder) + ieee_rounding = _constexpr_to_value(ieee_rounding) + return semantic.fdiv(x, y, ieee_rounding, _builder) def _add_math_1arg_docstr(name): @@ -723,31 +961,31 @@ def _add_math_1arg_docstr(name): @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): - return frontend.exp(x, _builder) + return semantic.exp(x, _builder) @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): - return frontend.log(x, _builder) + return semantic.log(x, _builder) @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): - return frontend.cos(x, _builder) + return semantic.cos(x, _builder) @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): - return frontend.sin(x, _builder) + return semantic.sin(x, _builder) @builtin @_add_math_1arg_docstr("square root") def sqrt(x, _builder=None): - return frontend.sqrt(x, _builder) + return semantic.sqrt(x, _builder) # ----------------------- @@ -758,7 +996,7 @@ def _add_reduction_docstr(name): def _decorator(func): docstr = """ - Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis` + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` :param input: the input values :param axis: the dimension along which the reduction should be done @@ -772,25 +1010,29 @@ def _add_reduction_docstr(name): @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): - return frontend.max(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.max(input, axis, _builder) @builtin @_add_reduction_docstr("minimum") def min(input, axis, _builder=None): - return frontend.min(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.min(input, axis, _builder) @builtin 
@_add_reduction_docstr("sum") def sum(input, axis, _builder=None): - return frontend.sum(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.sum(input, axis, _builder) @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): - return frontend.xor_sum(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.xor_sum(input, axis, _builder) # ----------------------- @@ -800,7 +1042,7 @@ def xor_sum(input, axis, _builder=None): @builtin def debug_barrier(_builder=None): - return frontend.debug_barrier(_builder) + return semantic.debug_barrier(_builder) @builtin @@ -808,7 +1050,8 @@ def multiple_of(input, value, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ - return frontend.multiple_of(input, value, _builder) + value = _constexpr_to_value(value) + return semantic.multiple_of(input, value) @builtin @@ -816,7 +1059,8 @@ def max_contiguous(input, value, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - return frontend.max_contiguous(input, value, _builder) + value = _constexpr_to_value(value) + return semantic.max_contiguous(input, value) # ----------------------- @@ -846,9 +1090,9 @@ def minimum(x, y): """ Computes the element-wise minimum of :code:`x` and :code:`y`. - :param input: the first input block + :param input: the first input tensor :type input: Block - :param other: the second input block + :param other: the second input tensor :type other: Block """ return triton.language.where(x < y, x, y) @@ -859,9 +1103,9 @@ def maximum(x, y): """ Computes the element-wise maximum of :code:`x` and :code:`y`. - :param input: the first input block + :param input: the first input tensor :type input: Block - :param other: the second input block + :param other: the second input tensor :type other: Block """ return triton.language.where(x > y, x, y) @@ -887,7 +1131,7 @@ def ravel(x): """ Returns a contiguous flattened view of :code:`x` - :param x: the input block + :param x: the input tensor :type x: Block """ return triton.language.reshape(x, [x.numel]) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py new file mode 100644 index 000000000..4063b86fc --- /dev/null +++ b/python/triton/language/semantic.py @@ -0,0 +1,1037 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Tuple + +from . 
import core as tl +from triton._C.libtriton.triton import ir + + +# Create custom exception that prints message "hello" +class IncompatibleTypeErrorimpl(Exception): + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorimpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + assert False + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + if not a_ty.is_int() or not b_ty.is_int(): + assert False + # 4 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorimpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorimpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorimpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, + arithmetic_check=True, div_or_mod=False + ) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + assert False + + +def sub(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), + input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + assert False + + +def mul(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + assert False + + +def truediv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and 
other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + assert False + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + assert False + + +def fdiv(input: tl.tensor, + other: tl.tensor, + ieee_rounding: bool, + builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise ValueError("both operands of fdiv must have floating poscalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + ret.set_fdiv_ieee_rounding(ieee_rounding) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + assert False + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def lshr(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, + builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, + builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> 
tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + assert False + + +def greater_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + assert False + + +def equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + assert False + + +def not_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + assert False + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + 
shape = [end - start] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.get_range(start, end), ret_ty) + + +def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + _0 = ir.constant.get_null_value(dtype.to_ir(builder)) + ret_ty = tl.block_type(dtype, shape) + return tl.tensor(builder.create_splat(_0, shape), ret_ty) + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def reshape(input: tl.tensor, + dst_shape: List[int], + builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("cannot reshape block of different shape") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + # TODO: check types + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type) + + +def broadcast_impl_shape(input: tl.tensor, + shape: List[int], + builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + if len(lhs_shape) != len(rhs_shape): + raise ValueError("Cannot make_shape_compatible: blocks must have the same rank") + ret_shape = [] + for i in range(len(lhs_shape)): + left = lhs_shape[i] + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif right == 1: + ret_shape.append(left) + elif left == right: + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + +####### +# cast +####### + + +def bitcast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) + if 
src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), + dst_ty) + + +def cast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # bf16 <=> (not fp32) + if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ + (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # FP Truncation + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # FP Extension + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Int cast + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + return tl.tensor(builder.create_int_cast(input.handle, + dst_ty.to_ir(builder), sign_extend), + dst_ty) + + # Float to Int + if src_sca_ty.is_floating() and dst_sca_ty.is_int(): + # TODO: is this correct? + if dst_sca_ty.is_bool(): + return tl.tensor(builder.create_fp_to_ui(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_fp_to_si(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # int => float + if src_sca_ty.is_int() and dst_sca_ty.is_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # ptr => int + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), + dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), + tl.tensor(builder.get_int64(0), tl.int64), + builder) + + if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + # Ptr . Ptr + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + # * . 
Bool + if dst_sca_ty.is_bool(): + if src_sca_ty.is_ptr(): + input = cast(input, tl.int64, builder) + other = builder.get_int64(0) + if src_ty.is_bool(): + other = builder.create_splat(other, src_ty.get_block_shapes()) + return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty) + assert False, f'cannot cast {input} to {dst_ty}' + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def load(ptr: tl.tensor, + mask: Optional[tl.tensor], + other: Optional[tl.tensor], + cache_modifier: str, + eviction_policy: str, + is_volatile: bool, + builder: ir.builder) -> tl.tensor: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + if other: + other = cast(other, ptr.type.scalar.element_ty, builder) + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + # treat bool* as tl.int8* + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # cache modifier + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + + # eviction policy + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + dst_ty = elt_ty + + if not mask and not other: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), + dst_ty) + if not mask: + raise ValueError("`other` cannot be provided without `mask`") + + if not other: + other_ir = ir.undef.get(elt_ty.to_ir(builder)) + if ptr.type.is_block(): + other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes()) + other = tl.tensor(other_ir, dst_ty) + + return tl.tensor(builder.create_masked_load(ptr.handle, + mask.handle, + other.handle, + cache, eviction, is_volatile), + dst_ty) + + +def store(ptr: tl.tensor, + val: tl.tensor, + mask: Optional[tl.tensor], + builder: ir.builder) -> tl.tensor: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + # treat bool* as tl.int8* + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # cast to target data-type + val = cast(val, elt_ty, builder) + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return 
tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, + cmp: tl.tensor, + val: tl.tensor, + builder: ir.builder) -> tl.tensor: + # TODO: type checking + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_min(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + i_ptr.handle, + i_val.handle, + and_(mask, pos, 
builder).handle), + i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + i_ptr.handle, + i_val.handle, + and_(mask, neg, builder).handle), + i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_add(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_and(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_or(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xor(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xchg(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def dot(lhs: tl.tensor, + rhs: tl.tensor, + allow_tf32: bool, + builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + if lhs.type.scalar.is_int(): + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + else: + _0 = builder.get_float32(0) + ret_scalar_ty = tl.float32 + M = lhs.type.shape[0] + N = rhs.type.shape[1] + _0 = builder.create_splat(_0, [M, N]) + ret_ty = tl.block_type(ret_scalar_ty, [M, N]) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + +def where(condition: tl.tensor, + x: tl.tensor, + y: tl.tensor, + builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) + y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) + + # TODO: we need to check x's and y's shape? 
+ x_ty = x.type.scalar + y_ty = y.type.scalar + ty = computation_type_impl(x_ty, y_ty, div_or_mod=False) + x = cast(x, ty, builder) + y = cast(y, ty, builder) + if x.type.is_block(): + ret_ty = tl.block_type(ty, x.type.shape) + else: + ret_ty = ty + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reductions +# ===----------------------------------------------------------------------=== + +def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, + FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: + scalar_ty = input.type.scalar + # input is extended to 32-bits if necessary + # this increases numerical accuracy and can be done pretty much for free + # on GPUs + if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: + input = cast(input, tl.int32, builder) + + # get result type + shape = input.type.shape + ret_shape = [] + for i, s in enumerate(shape): + if i != axis: + ret_shape.append(s) + if len(ret_shape) == 0: + res_ty = scalar_ty + else: + res_ty = tl.block_type(scalar_ty, ret_shape) + + if scalar_ty.is_floating(): + return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) + elif scalar_ty.is_int(): + return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty) + assert False + + +def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) + + +def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) + + +def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) + + +def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) + + +# ===----------------------------------------------------------------------=== +# Math +# ===----------------------------------------------------------------------=== + +def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + x, y = binary_op_type_checking_impl(x, y, builder) + return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type) + + +def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_exp(x.handle), x.type) + + +def log(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_log(x.handle), x.type) + + +def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_cos(x.handle), x.type) + + +def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sin(x.handle), x.type) + + +def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sqrt(x.handle), x.type) + + +## + +def multiple_of(x: tl.tensor, value: int) -> tl.tensor: + x.handle.multiple_of(value) + return x + + +def max_contiguous(x: tl.tensor, value: int) -> tl.tensor: + x.handle.max_contiguous(value) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(''), tl.void) From a4f68165cd7b2b7dd395175cf2b1d17b3163a8c4 Mon Sep 17 00:00:00 2001 From: Keren 
Zhou Date: Wed, 23 Mar 2022 00:09:49 -0500 Subject: [PATCH 079/215] [FRONTEND] Hot fix for lineno (#481) Override __reduce__ to make CompilationError pickable and print out error messages --- python/triton/code_gen.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index a253e2c4c..23d460f29 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -654,9 +654,13 @@ class CompilationError(Exception): self.message = f'at {node.lineno}:{node.col_offset}:\n' self.message += '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' + self.src = src + self.node = node super().__init__(self.message) + + def __reduce__(self): # this is necessary to make CompilationError picklable - self.args = (src, node) + return (type(self), (self.src, self.node)) class OutOfResources(Exception): @@ -664,8 +668,14 @@ class OutOfResources(Exception): self.message = f'out of resource: {name}, '\ f'Required: {required}, '\ f'Hardware limit: {limit}' + self.required = required + self.limit = limit + self.name = name super().__init__(self.message) - self.args = (required, limit, name) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) class Kernel: From ea6d1f1b8511c95a8c0f68460ffa62cd31a1a8a8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 23 Mar 2022 00:24:45 -0700 Subject: [PATCH 080/215] [DRIVER] LLVM driver fixup (#482) Current way of doing things is probably not super thread safe. init is shared between threads and some threads my not call the LLVMInitialize* function. --- lib/driver/llvm.cc | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index e7bb47bef..0d1c9c3d2 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -70,18 +70,14 @@ namespace triton{ namespace driver{ void init_llvm() { - static bool init = false; - if(!init){ - LLVMInitializeNVPTXTargetInfo(); - LLVMInitializeNVPTXTarget(); - LLVMInitializeNVPTXTargetMC(); - LLVMInitializeNVPTXAsmPrinter(); - LLVMInitializeAMDGPUTargetInfo(); - LLVMInitializeAMDGPUTarget(); - LLVMInitializeAMDGPUTargetMC(); - LLVMInitializeAMDGPUAsmPrinter(); - init = true; - } + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + LLVMInitializeAMDGPUTargetInfo(); + LLVMInitializeAMDGPUTarget(); + LLVMInitializeAMDGPUTargetMC(); + LLVMInitializeAMDGPUAsmPrinter(); } @@ -169,8 +165,6 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ // verify and store llvm llvm::legacy::PassManager pm; pm.add(llvm::createVerifierPass()); - // pm.add(llvm::createDeadCodeEliminationPass()); - // pm.add(llvm::createEarlyCSEPass()); pm.run(*module); // module->print(llvm::outs(), nullptr); From 76a9ee50a850d402c8d0ea6052755ae405088f06 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 24 Mar 2022 17:16:50 -0700 Subject: [PATCH 081/215] Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483) This reverts commit 539961072c277221e4e1dfda81e718daf9bbc1c7. 
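(A minimal, self-contained sketch of why the __reduce__ override in the lineno hot fix (#481) above is needed. Python's default exception pickling re-invokes the constructor with self.args, which after super().__init__(self.message) holds only the formatted message and no longer matches the two-argument constructor, so unpickling fails; the earlier workaround of overwriting self.args kept pickling alive but made str(exc) print the raw argument tuple instead of the message. The class below is a simplified stand-in using plain values, not the real CompilationError.)

import pickle

class CompilationError(Exception):
    # simplified stand-in: src/node are plain values, not kernel source and an AST node
    def __init__(self, src, node):
        self.src = src
        self.node = node
        self.message = f'error in {src!r} at node {node}'
        super().__init__(self.message)

    def __reduce__(self):
        # rebuild from the original constructor arguments so the exception
        # survives a pickle round-trip (e.g. across multiprocessing workers)
        return (type(self), (self.src, self.node))

err = CompilationError('kernel.py', 7)
clone = pickle.loads(pickle.dumps(err))
assert clone.src == 'kernel.py' and str(clone) == str(err)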
--- include/triton/ir/builder.h | 31 +- include/triton/ir/context_impl.h | 1 + include/triton/ir/dispatch.h | 113 +++ include/triton/ir/module.h | 33 +- include/triton/ir/type.h | 20 +- lib/ir/builder.cc | 53 +- lib/ir/context.cc | 18 +- lib/ir/dispatch.cc | 882 +++++++++++++++++ lib/ir/instructions.cc | 4 +- lib/ir/module.cc | 140 +++ lib/ir/type.cc | 14 + python/src/triton.cc | 392 +++----- python/test/regression/test_performance.py | 2 +- python/test/unit/language/test_core.py | 23 + python/test/unit/runtime/test_cache.py | 28 - python/triton/__init__.py | 3 +- python/triton/code_gen.py | 314 ++---- python/triton/language/core.py | 606 ++++-------- python/triton/language/semantic.py | 1037 -------------------- 19 files changed, 1670 insertions(+), 2044 deletions(-) create mode 100644 include/triton/ir/dispatch.h create mode 100644 lib/ir/dispatch.cc delete mode 100644 python/triton/language/semantic.py diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index fe85be947..2b6bc6ab3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -38,8 +38,10 @@ public: iterator get_insert_point() { return insert_point_;} // Constants value *get_int1(bool val); - value *get_int32(uint32_t val); - value *get_int64(uint64_t val); + value *get_int32(int32_t val); + value *get_int64(int64_t val); + value *get_uint32(uint32_t val); + value *get_uint64(uint64_t val); value *get_float16(float val); value *get_float32(float val); value *get_range(int32_t lo, int32_t hi); @@ -50,9 +52,11 @@ public: type *get_int16_ty(); type *get_int32_ty(); type *get_int64_ty(); - type *get_fp8_ty(); + type *get_uint8_ty(); + type *get_uint16_ty(); + type *get_uint32_ty(); + type *get_uint64_ty(); type *get_half_ty(); - type *get_bf16_ty(); type *get_float_ty(); type *get_double_ty(); // Insert @@ -70,9 +74,7 @@ public: value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); // Cast instructions - value* create_bitcast(value *src, type *dest_ty); value *create_cast(cast_op_t op, value *v, type *dst_ty); - value* create_int_to_ptr(value *src, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty); value* create_si_to_fp(value *src, type *dst_ty); value* create_ui_to_fp(value *src, type *dst_ty); @@ -91,11 +93,11 @@ public: value *create_frem(value *lhs, value *rhs); value *create_fadd(value *lhs, value *rhs); value *create_fsub(value *lhs, value *rhs); + value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sdiv(value *lhs, value *rhs); value *create_udiv(value *lhs, value *rhs); value *create_srem(value *lhs, value *rhs); value *create_urem(value *lhs, value *rhs); - value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); @@ -143,22 +145,11 @@ public: value *create_reshape(value *arg, const type::block_shapes_t &shapes); value *create_cat(value *lhs, value *rhs); value *create_broadcast(value *arg, const type::block_shapes_t &shapes); - // Atomic instruction - value *create_atomic_cas(value *ptr, value *cmp, value *val); - value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk); - value *create_atomic_max(value *ptr, value *val, value *msk); - value 
*create_atomic_umax(value *ptr, value *val, value *msk); - value *create_atomic_min(value *ptr, value *val, value *msk); - value *create_atomic_umin(value *ptr, value *val, value *msk); - value *create_atomic_fadd(value *ptr, value *val, value *msk); - value *create_atomic_add(value *ptr, value *val, value *msk); - value *create_atomic_and(value *ptr, value *val, value *msk); - value *create_atomic_or(value *ptr, value *val, value *msk); - value *create_atomic_xor(value *ptr, value *val, value *msk); - value *create_atomic_xchg(value *ptr, value *val, value *msk); // Built-in instruction value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); + value *create_atomic_cas(value *ptr, value *cmp, value *val); + value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk); value *create_exp(value* arg); value *create_cos(value* arg); value *create_sin(value* arg); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index ef20af6b7..081ea249d 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -26,6 +26,7 @@ public: type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; // integer types integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; + integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types std::map, std::unique_ptr> ptr_tys; // Block types diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h new file mode 100644 index 000000000..ef14043dd --- /dev/null +++ b/include/triton/ir/dispatch.h @@ -0,0 +1,113 @@ +#pragma once + +#ifndef _TRITON_IR_DISPATCH_H_ +#define _TRITON_IR_DISPATCH_H_ + +#include "triton/ir/builder.h" +#include + +namespace triton{ +namespace ir{ + + +/*---------------------------------------------- + higher level functions that follow the likely + semantics of most expected frontends + ----------------------------------------------*/ + +struct semantic_error: public std::runtime_error { + semantic_error(const std::string& msg): + std::runtime_error(msg) { } +}; + +struct dispatch{ + typedef ir::type::block_shapes_t shape_t; + + + // programming model + static ir::value *program_id(int axis, ir::builder *builder); + static ir::value *num_programs(int axis, ir::builder *builder); + + // binary operators + static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder); + static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder); + + // unary operators + static ir::value *plus(ir::value *input, ir::builder *builder); + static ir::value *minus(ir::value *input, ir::builder *builder); + static ir::value 
*invert(ir::value *input, ir::builder *builder); + + // comparison operators + static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder); + + // block creation + static ir::value* arange(int start, int end, ir::builder *builder); + static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder); + + + // casting ops + static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder); + static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder); + static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder); + static std::tuple broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder); + static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder); + static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); + + // memory operators + static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, + const std::string& eviction_policy, int is_volatile, ir::builder *builder); + static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); + static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); + static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); + + // linear algebra + static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder); + + // indexing + static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); + + // reduction + static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder); + + // math + static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); + static ir::value *exp(ir::value *x, ir::builder *builder); + static ir::value *log(ir::value *x, ir::builder *builder); + static ir::value *cos(ir::value *x, ir::builder *builder); + static ir::value *sin(ir::value *x, ir::builder *builder); + static ir::value *sqrt(ir::value *x, ir::builder *builder); + + // internal (debug/optimization) + static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); + static 
ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); + static ir::value *debug_barrier(ir::builder *builder); +}; + +} +} + +#endif diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index ea64dfc6e..30881fd49 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -57,10 +57,26 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - module(const std::string &name, builder &builder): name_(name), builder_(builder) {} - builder &get_builder() { return builder_; }; - const std::string& get_name() { return name_; }; + module(const std::string &name, builder& builder); + builder& get_builder(); + // Setters + void set_value(const std::string& name, basic_block* block, value *x); + void set_value(const std::string& name, value* x); + void set_const(const std::string& name); + void set_continue_fn(std::function fn); + // Getters + const std::map& get_values() { return values_; } + const std::map& get_types() { return types_; } + void set_values(const std::map& values) { values_ = values; } + void set_types(const std::map& types) { types_ = types; } + value *get_value(const std::string& name, basic_block* block); + value *get_value(const std::string& name); + void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } + const std::string& get_name(); + std::function get_continue_fn(); + // Seal block -- no more predecessors will be added + void seal_block(basic_block *block); // Functions const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } @@ -73,14 +89,21 @@ public: const std::map& globals() const { return globals_; } // Metadata void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } - const std::map &get_metadatas() const { return metadatas_; } + void print(std::ostream &os); private: std::string name_; - builder &builder_; + builder& builder_; + std::map values_; + std::map types_; + std::set const_; + std::set sealed_blocks_; + std::map> incomplete_phis_; functions_list_t functions_; symbols_map_t symbols_; + std::function continue_fn_; + std::map current_phi_; std::vector allocs_; std::map globals_; std::map metadatas_; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index b1ef1ad22..47c9b5f85 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -16,6 +16,8 @@ class value; class integer_type; class constant_int; +enum class signedness { SIGNED, UNSIGNED }; + /* Type */ class type { public: @@ -59,6 +61,8 @@ public: // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bitwidth() const; + signedness get_integer_signedness() const; + bool is_integer_signed() const; unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; @@ -81,6 +85,9 @@ public: bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; } + bool is_integer_ty(unsigned bitwidth, signedness sn) { + return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn; + } bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } @@ -108,6 +115,10 @@ public: static integer_type *get_int32_ty(context &ctx); static integer_type *get_int64_ty(context &ctx); static 
integer_type *get_int128_ty(context &ctx); + static integer_type *get_uint8_ty(context &ctx); + static integer_type *get_uint16_ty(context &ctx); + static integer_type *get_uint32_ty(context &ctx); + static integer_type *get_uint64_ty(context &ctx); // repr std::string tile_repr() const { @@ -134,7 +145,7 @@ public: case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; - case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth()); + case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth()); case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; @@ -157,18 +168,21 @@ class integer_type: public type { private: // constructors - integer_type(context &ctx, unsigned bitwidth) - : type(ctx, IntegerTyID), bitwidth_(bitwidth) {} + integer_type(context &ctx, unsigned bitwidth, signedness sn) + : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ } public: // accessors unsigned get_bitwidth() const { return bitwidth_; } + signedness get_signedness() const { return signedness_; } + // factory methods static integer_type* get(context &ctx, unsigned width); private: unsigned bitwidth_; + signedness signedness_; }; class composite_type: public type{ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 9b8a2a45e..fff73e665 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -48,12 +48,18 @@ void builder::set_insert_point(basic_block *block){ value *builder::get_int1(bool val) { return constant_int::get(type::get_int1_ty(ctx_), val); } -value *builder::get_int32(uint32_t val) +value *builder::get_int32(int32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} -value *builder::get_int64(uint64_t val) +value *builder::get_uint32(uint32_t val) +{ return constant_int::get(type::get_uint32_ty(ctx_), val);} + +value *builder::get_int64(int64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} +value *builder::get_uint64(uint64_t val) +{ return constant_int::get(type::get_uint64_ty(ctx_), val);} + value *builder::get_float16(float val) { return constant_fp::get(type::get_fp16_ty(ctx_), val); } @@ -84,15 +90,21 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } -type *builder::get_fp8_ty() -{ return type::get_fp8_ty(ctx_); } +type *builder::get_uint8_ty() +{ return type::get_uint8_ty(ctx_); } + +type *builder::get_uint16_ty() +{ return type::get_uint16_ty(ctx_); } + +type *builder::get_uint32_ty() +{ return type::get_uint32_ty(ctx_); } + +type *builder::get_uint64_ty() +{ return type::get_uint64_ty(ctx_); } type *builder::get_half_ty() { return type::get_fp16_ty(ctx_); } -type *builder::get_bf16_ty() -{ return type::get_bf16_ty(ctx_); } - type *builder::get_float_ty() { return type::get_fp32_ty(ctx_); } @@ -127,8 +139,6 @@ value *builder::create_ret_void() { return create_cast(OPCODE, src, dst_ty);\ } -DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast) -DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr) DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) @@ -321,28 +331,6 @@ value *builder::create_downcast(value *arg) { return insert(downcast_inst::create(arg)); } -// - -value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ - return insert(atomic_rmw_inst::create(op, ptr, val, msk)); -} - -#define 
DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\ - value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\ - return create_atomic_rmw(OPCODE, ptr, val, mask);\ - } - -DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max) -DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax) -DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min) -DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin) -DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd) -DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) -DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And) -DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or) -DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor) -DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg) - //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// @@ -359,6 +347,9 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ return insert(atomic_cas_inst::create(ptr, cmp, val)); } +value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ + return insert(atomic_rmw_inst::create(op, ptr, val, msk)); +} value *builder::create_exp(value *arg){ return insert(exp_inst::create(arg)); diff --git a/lib/ir/context.cc b/lib/ir/context.cc index 0fc65ddc2..90b109b9b 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -19,12 +19,18 @@ context_impl::context_impl(context &ctx) fp32_ty(ctx, type::FP32TyID), fp64_ty(ctx, type::FP64TyID), // integers - int1_ty(ctx, 1), - int8_ty(ctx, 8), - int16_ty(ctx, 16), - int32_ty(ctx, 32), - int64_ty(ctx, 64), - int128_ty(ctx, 128) {} + int1_ty(ctx, 1, signedness::SIGNED), + int8_ty(ctx, 8, signedness::SIGNED), + int16_ty(ctx, 16, signedness::SIGNED), + int32_ty(ctx, 32, signedness::SIGNED), + int64_ty(ctx, 64, signedness::SIGNED), + int128_ty(ctx, 128, signedness::SIGNED), + uint8_ty(ctx, 8, signedness::UNSIGNED), + uint16_ty(ctx, 16, signedness::UNSIGNED), + uint32_ty(ctx, 32, signedness::UNSIGNED), + uint64_ty(ctx, 64, signedness::UNSIGNED){ + +} //===----------------------------------------------------------------------===// // context diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc new file mode 100644 index 000000000..664fbb983 --- /dev/null +++ b/lib/ir/dispatch.cc @@ -0,0 +1,882 @@ +#include "triton/ir/dispatch.h" + +namespace triton { +namespace ir { + + +[[ noreturn ]] void throw_unreachable(std::string key) { + throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. 
" + "This is likely a bug on our side."); +} + +//===----------------------------------------------------------------------===// +// Programming Model +//===----------------------------------------------------------------------===// + +ir::value *dispatch::program_id(int axis, ir::builder *builder) { + return builder->create_get_program_id(axis); +} + +ir::value *dispatch::num_programs(int axis, ir::builder *builder) { + return builder->create_get_num_programs(axis); +} + +//===----------------------------------------------------------------------===// +// Implicit Casting Utilities +//===----------------------------------------------------------------------===// + +ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ + int a_rank = a_ty->get_integer_bitwidth(); + int b_rank = b_ty->get_integer_bitwidth(); + auto a_sn = a_ty->get_integer_signedness(); + auto b_sn = b_ty->get_integer_signedness(); + // Rules for signedness taken from "Usual arithmetic conversions" on + // https://en.cppreference.com/w/c/language/conversion. + if (a_sn == b_sn) { + return a_rank > b_rank ? a_ty : b_ty; + } else if (a_sn == signedness::UNSIGNED) { + return a_rank >= b_rank ? a_ty : b_ty; + } else if (b_sn == signedness::UNSIGNED) { + return b_rank >= a_rank ? b_ty : a_ty; + } else { + throw_unreachable("integer_promote"); + } +} + +enum class DivOrMod { NO, YES }; + +ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) { + context &ctx = a_ty->get_context(); + // 1) if one operand is double, the other is implicitly + // converted to double + if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) + return type::get_fp64_ty(ctx); + // 2) if one operand is float, the other is implicitly + // converted to float + if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) + return type::get_fp32_ty(ctx); + // 3 ) if one operand is half, the other is implicitly converted to half + // unless we're doing / or %, which do not exist natively in PTX for fp16. + if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) { + if (div_or_mod == DivOrMod::YES) { + return type::get_fp32_ty(ctx); + } else { + return type::get_fp16_ty(ctx); + } + } + if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) + throw_unreachable("computation_type"); + // 4 ) both operands are integer and undergo + // integer promotion + if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) { + throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. 
Cast them to the same signedness."); + } + return integer_promote(a_ty, b_ty); +} + +//===----------------------------------------------------------------------===// +// Binary Operators +//===----------------------------------------------------------------------===// + +void throw_incompatible_types(ir::type* type_a, ir::type* type_b) { + throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr()); +} + +void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){ + + if(type_a->is_pointer_ty()){ + if(!allow_ptr_a) + throw_incompatible_types(type_a, type_b); + // T* + U* with T != U + if(type_b->is_pointer_ty() && (type_a != type_b)) + throw_incompatible_types(type_a, type_b); + // T* + float + if(type_b->is_floating_point_ty()) + throw_incompatible_types(type_a, type_b); + } +} + +void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, + bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, + bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) { + // implicit broadcasting + std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); + // implicit typecasting + ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty(); + ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); + check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); + check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); + if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) { + ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod); + lhs = dispatch::cast(lhs, ret_sca_ty, builder); + rhs = dispatch::cast(rhs, ret_sca_ty, builder); + } +} + +ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, true, true); + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + // offset + ptr + // ptr + offset + if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty()) + std::swap(input, other); + if (input_scalar_ty->is_pointer_ty()) + return builder->create_gep(input, {other}); + // float + float + else if (input_scalar_ty->is_floating_point_ty()) + return builder->create_fadd(input, other); + // int + int + else if (input_scalar_ty->is_integer_ty()) + return builder->create_add(input, other); + throw_unreachable("add"); +} + +ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, true, false); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // ptr - offset + if (scalar_ty->is_pointer_ty()) + return builder->create_gep(input, {dispatch::minus(other, builder)}); + // float + float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fsub(input, other); + // int + int + else if (scalar_ty->is_integer_ty()) + return builder->create_sub(input, other); + throw_unreachable("sub"); +} + +ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float * float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fmul(input, other); + // int * int + else if (scalar_ty->is_integer_ty()) + return builder->create_mul(input, other); + throw_unreachable("mul"); +} + +ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { + 
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + // float / int + if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty()) + other = cast(other, input_scalar_ty, builder); + // int / float + else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty()) + input = cast(input, other_scalar_ty, builder); + // int / int (cast to float32) + else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ + input = cast(input, builder->get_float_ty(), builder); + other = cast(other, builder->get_float_ty(), builder); + } + // float / float (cast to highest exponent type) + else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){ + if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width()) + other = cast(other, input_scalar_ty, builder); + else + input = cast(input, other_scalar_ty, builder); + } + // unreachable + else + throw_unreachable("div"); + return builder->create_fdiv(input, other); +} + +ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ + binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ + ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); + input = dispatch::cast(input, ret_ty, builder); + other = dispatch::cast(other, ret_ty, builder); + if (ret_ty->is_integer_signed()) { + return builder->create_sdiv(input, other); + } else { + return builder->create_udiv(input, other); + } + } + throw_unreachable("floordiv"); +} + +ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){ + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty()) + throw semantic_error("both operands of fdiv must have floating point scalar type"); + binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES); + ir::value* ret = builder->create_fdiv(input, other); + if(ir::binary_operator* binop = dynamic_cast(ret)) + binop->set_fdiv_ieee_rounding(ieee_rounding->get_value()); + return ret; +} + +ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + // float % int + if (scalar_ty->is_floating_point_ty()) + return builder->create_frem(input, other); + // int % int + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) { + throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. 
Cast them to the same signedness."); + } + if (scalar_ty->is_integer_signed()) { + return builder->create_srem(input, other); + } else { + return builder->create_urem(input, other); + } + } + throw_unreachable("mod"); +} + + +void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) { + binary_op_type_checking(input, other, builder, false, false, false); + ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); + ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); + if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) + throw_incompatible_types(input_sca_ty, other_sca_ty); + ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty); + if (ret_sca_ty != input_sca_ty) + input = dispatch::cast(input, ret_sca_ty, builder); + if (ret_sca_ty != other_sca_ty) + other = dispatch::cast(other, ret_sca_ty, builder); +} + +ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder); + return builder->create_and(input, other); +} + +ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder); + return builder->create_or(input, other); +} + + +ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder); + return builder->create_xor(input, other); +} + + +ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder); + return builder->create_lshr(input, other); +} + + +ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { + bitwise_op_type_checking(input, other, builder); + return builder->create_shl(input, other); +} + +//===----------------------------------------------------------------------===// +// Unary Operators +//===----------------------------------------------------------------------===// + +ir::value *dispatch::plus(ir::value *input, ir::builder *) { + return input; +} + +ir::value *dispatch::minus(ir::value *input, ir::builder *builder) { + ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); + if(input_sca_ty->is_pointer_ty()) + throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")"); + ir::value *_0 = ir::constant::get_null_value(input_sca_ty); + return dispatch::sub(_0, input, builder); +} + +ir::value *dispatch::invert(ir::value *input, ir::builder *builder) { + ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); + if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty()) + throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")"); + ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty); + return dispatch::xor_(input, _1, builder); +} + + +//===----------------------------------------------------------------------===// +// Comparison Operators +//===----------------------------------------------------------------------===// + +ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float > float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOGT(input, other); + // int > int + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSGT(input, other); + } else { 
+ return builder->create_icmpUGT(input, other); + } + } + throw_unreachable("greater_than"); +} + +ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float >= float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOGE(input, other); + // int >= int + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSGE(input, other); + } else { + return builder->create_icmpUGE(input, other); + } + } + throw_unreachable("greater_equal"); +} + +ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float < float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOLT(input, other); + // int < int + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSLT(input, other); + } else { + return builder->create_icmpULT(input, other); + } + } + throw_unreachable("less_than"); +} + +ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float < float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOLE(input, other); + // int < int + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSLE(input, other); + } else { + return builder->create_icmpULE(input, other); + } + } + throw_unreachable("less_equal"); +} + +ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float == float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpOEQ(input, other); + // int == int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpEQ(input, other); + throw_unreachable("equal"); +} + +ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { + binary_op_type_checking(input, other, builder); + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // float == float + if (scalar_ty->is_floating_point_ty()) + return builder->create_fcmpUNE(input, other); + // int == int + else if (scalar_ty->is_integer_ty()) + return builder->create_icmpNE(input, other); + throw_unreachable("equal"); +} + +//===----------------------------------------------------------------------===// +// Block Creation +//===----------------------------------------------------------------------===// + +ir::value* dispatch::arange(int start, int end, ir::builder *builder) { + return builder->get_range(start, end); +} + +ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) { + ir::value *_0 = ir::constant::get_null_value(dtype); + return builder->create_splat(_0, shape); +} + +//===----------------------------------------------------------------------===// +// Shape Manipulation +//===----------------------------------------------------------------------===// + + +ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) { + unsigned numel = 1; + for(unsigned s: dst_shape) numel *= s; + 
if(input->get_type()->get_tile_num_elements() != numel) + throw semantic_error("cannot reshape block of different shape"); + return builder->create_reshape(input, dst_shape); +} + +ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) { + return builder->create_cat(lhs, rhs); +} + +ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) { + if (!input->get_type()->is_block_ty()) + return builder->create_splat(input, shape); + auto src_shape = input->get_type()->get_block_shapes(); + if (src_shape.size() != shape.size()) + throw std::runtime_error("Cannot broadcast"); + if(shape == src_shape) + return input; + return builder->create_broadcast(input, shape); +} + +std::tuple dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { + ir::type *lhs_ty = lhs->get_type(); + ir::type *rhs_ty = rhs->get_type(); + + // make_shape_compatible(block, scalar) + if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) + rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); + // make_shape_compatible(scalar, block) + else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) + lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); + // make_shape_compatible(block, block) + else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { + auto lhs_shape = lhs_ty->get_block_shapes(); + auto rhs_shape = rhs_ty->get_block_shapes(); + if (lhs_shape.size() != rhs_shape.size()) + throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); + ir::type::block_shapes_t ret_shape; + for (size_t i = 0; i < lhs_shape.size(); ++i) { + unsigned left = lhs_shape[i]; + unsigned right = rhs_shape[i]; + if (left == 1) + ret_shape.push_back(right); + else if (right == 1) + ret_shape.push_back(left); + else if (left == right) + ret_shape.push_back(left); + else + throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + + ": " + std::to_string(left) + " and " + std::to_string(right)); + } + if (lhs_shape != ret_shape) + lhs = builder->create_broadcast(lhs, ret_shape); + if (rhs_shape != ret_shape) + rhs = builder->create_broadcast(rhs, ret_shape); + } + return std::make_tuple(lhs, rhs); +} + +ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){ + ir::type *src_ty = input->get_type(); + if (src_ty->is_block_ty()) + dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); + if(src_ty == dst_ty) + return input; + ir::type *src_sca_ty = src_ty->get_scalar_ty(); + ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); + if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty()) + return cast(input, dst_ty, builder); + // Bitcast + int src_bits = src_sca_ty->get_primitive_size_in_bits(); + int dst_bits = dst_sca_ty->get_primitive_size_in_bits(); + if( src_bits!= dst_bits) + throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) + + "to data-type of size " + std::to_string(dst_bits)); + return builder->create_cast(ir::BitCast, input, dst_ty); +} + +ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) { + ir::type *src_ty = input->get_type(); + if (src_ty->is_block_ty()) + dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); + if(src_ty == dst_ty) + return input; + ir::type *src_sca_ty = src_ty->get_scalar_ty(); + ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); + // + if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) || + 
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){ + return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder); + } + // FP Truncation + bool truncate_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); + if (truncate_fp) + return builder->create_fp_trunc(input, dst_ty); + // FP Extension + bool ext_fp = src_sca_ty->is_floating_point_ty() && + dst_sca_ty->is_floating_point_ty() && + src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); + if (ext_fp) + return builder->create_fp_ext(input, dst_ty); + // Int cast + if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && + (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() || + src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) { + bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty(); + return builder->create_int_cast(input, dst_ty, sign_extend); + } + // Float -> Int + if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ + if(dst_sca_ty->is_bool_ty()) + return builder->create_fp_to_ui(input, dst_ty); + else + return builder->create_fp_to_si(input, dst_ty); + } + // int -> Float + if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ + if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed()) + return builder->create_ui_to_fp(input, dst_ty); + else + return builder->create_si_to_fp(input, dst_ty); + } + if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){ + int bitwidth = dst_sca_ty->get_integer_bitwidth(); + if(bitwidth == 64) + return builder->create_cast(ir::PtrToInt, input, dst_ty); + if(bitwidth == 1) + return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder), + builder->get_int64(0), + builder); + } + if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) + return builder->create_cast(ir::IntToPtr, input, dst_ty); + // Ptr -> Ptr + if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) + return builder->create_cast(ir::BitCast, input, dst_ty); + // * -> Bool + if (dst_sca_ty->is_bool_ty()) { + if (src_sca_ty->is_pointer_ty()) + input = cast(input, builder->get_int64_ty(), builder); + ir::value *other = builder->get_int64(0); + if (src_ty->is_bool_ty()) + other = builder->create_splat(other, src_ty->get_block_shapes()); + return builder->create_icmpNE(input, other); + } + throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); +} + +//===----------------------------------------------------------------------===// +// Memory Operators +//===----------------------------------------------------------------------===// + +ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) { + if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) + throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); + if(ptr->get_type()->is_block_ty()){ + if(mask) + mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); + if(other) + other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); + } + if(other) + other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); + ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); + ir::type 
*elt_ty = ptr_ty->get_pointer_element_ty(); + // treat bool* as int8* + if(elt_ty == builder->get_int1_ty()){ + elt_ty = builder->get_int8_ty(); + ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); + ptr = dispatch::cast(ptr, ptr_ty, builder); + } + // cache modifier + load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default + if (!cache_modifier.empty()) { + if (cache_modifier == ".ca") + cache = load_inst::CA; + else if (cache_modifier == ".cg") + cache = load_inst::CG; + else + throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); + } + // eviction policy + load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default + if(!eviction_policy.empty()){ + if (eviction_policy == "evict_last") + eviction = load_inst::EVICT_LAST; + else if(eviction_policy == "evict_first") + eviction = load_inst::EVICT_FIRST; + else + throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported"); + } + + + if (!mask && !other) + return builder->create_load(ptr, cache, eviction, is_volatile); + if (!mask) + throw std::runtime_error("`other` cannot be provided without `mask`"); + auto shape = ptr->get_type()->get_block_shapes(); + if(!other){ + other = ir::undef_value::get(elt_ty); + if(ptr->get_type()->is_block_ty()) + other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); + } + return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile); +} + +ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { + if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) + throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); + if(ptr->get_type()->is_block_ty()) + val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); + if(mask) + mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); + ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); + ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); + // treat bool* as int8* + if(elt_ty == builder->get_int1_ty()){ + elt_ty = builder->get_int8_ty(); + ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); + ptr = dispatch::cast(ptr, ptr_ty, builder); + } + // cast to target data-type + val = dispatch::cast(val, elt_ty, builder); + if (!mask) + return builder->create_store(ptr, val); + if(!mask->get_type()->get_scalar_ty()->is_bool_ty()) + throw semantic_error("Mask must have boolean scalar type"); + return builder->create_masked_store(ptr, val, mask); +} + +ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){ + return builder->create_atomic_cas(ptr, cmp, val); +} + +void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){ + if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) + throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); + if(ptr->get_type()->is_block_ty()){ + if(mask){ + mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); + } + if(val){ + val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); + } + } + val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); + if(!mask){ + mask = builder->get_int1(true); + if(ptr->get_type()->is_block_ty()) + mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes()); + } +} + +ir::value 
*dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + ir::type* sca_ty = val->get_type()->get_scalar_ty(); + // direct call to atomic_max for integers + if(sca_ty->is_integer_ty()) { + if (sca_ty->is_integer_signed()) { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); + } else { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask); + } + } + // for float + // return atomic_smax(i_ptr, i_val) if val >= 0 + // return atomic_umin(i_ptr, i_val) if val < 0 + ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); + ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); + ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); + ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); + ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder)); + ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder)); + return where(pos, pos_ret, neg_ret, builder); +} + +ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + ir::type* sca_ty = val->get_type()->get_scalar_ty(); + // direct call to atomic_min for integers + if(sca_ty->is_integer_ty()) { + if (sca_ty->is_integer_signed()) { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); + } else { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask); + } + } + // for float + // return atomic_smin(i_ptr, i_val) if val >= 0 + // return atomic_umax(i_ptr, i_val) if val < 0 + ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); + ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); + ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); + ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); + ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder)); + ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder)); + return where(pos, pos_ret, neg_ret, builder); +} + +ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + ir::type* sca_ty = val->get_type()->get_scalar_ty(); + auto op = sca_ty->is_floating_point_ty() ? 
ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add; + return builder->create_atomic_rmw(op, ptr, val, mask); +} + +ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask); +} + +ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask); +} + +ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask); +} + +ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + atom_red_typechecking(ptr, val, mask, builder); + ir::type* sca_ty = val->get_type()->get_scalar_ty(); + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask); +} + +//===----------------------------------------------------------------------===// +// Linear Algebra +//===----------------------------------------------------------------------===// + +ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) { + ir::value *_0 = nullptr; + if (lhs->get_type()->is_int_or_tileint_ty()) + _0 = builder->get_int32(0); + else + _0 = builder->get_float32(0); + unsigned M = lhs->get_type()->get_block_shapes()[0]; + unsigned N = rhs->get_type()->get_block_shapes()[1]; + _0 = builder->create_splat(_0, {M, N}); + bool _allow_tf32 = allow_tf32->get_value() != 0; + return builder->create_dot(lhs, rhs, _0, _allow_tf32); +} + + +//===----------------------------------------------------------------------===// +// Indexing +//===----------------------------------------------------------------------===// + +ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){ + condition = dispatch::cast(condition, builder->get_int1_ty(), builder); + if(condition->get_type()->is_block_ty()){ + x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); + y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); + } + ir::type* x_ty = x->get_type()->get_scalar_ty(); + ir::type* y_ty = y->get_type()->get_scalar_ty(); + ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO); + x = dispatch::cast(x, ty, builder); + y = dispatch::cast(y, ty, builder); + return builder->create_select(condition, x, y); +} + + +//===----------------------------------------------------------------------===// +// Reductions +//===----------------------------------------------------------------------===// + +ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, + ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + // input is extended to 32-bits if necessary + // this increases numerical accuracy and can be done pretty much for free + // on GPUs + if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32) + input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder); + if (scalar_ty->is_floating_point_ty()) + return builder->create_reduce(input, FLOAT_OP, axis); + else if (scalar_ty->is_integer_ty()) + 
return builder->create_reduce(input, INT_OP, axis); + throw_unreachable(name); +} + +ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); +} + +ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); +} + +ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); +} + +ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) { + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + if (!scalar_ty->is_integer_ty()) + throw semantic_error("xor_sum only supported for integers"); + return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR); +} + + +//===----------------------------------------------------------------------===// +// Math +//===----------------------------------------------------------------------===// + +ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) { + binary_op_type_checking(x, y, builder); + return builder->insert(umulhi_inst::create(x, y)); +} + +ir::value *dispatch::exp(ir::value *x, ir::builder *builder) { + return builder->create_exp(x); +} + +ir::value *dispatch::log(ir::value *x, ir::builder *builder) { + return builder->create_log(x); +} + +ir::value *dispatch::cos(ir::value *x, ir::builder *builder) { + return builder->create_cos(x); +} + +ir::value *dispatch::sin(ir::value *x, ir::builder *builder) { + return builder->create_sin(x); +} + +ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { + return builder->create_sqrt(x); +} + + +// + +ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ + ir::instruction* i = dynamic_cast(x); + if(!i) + throw_unreachable("multiple_of"); + i->set_metadata(ir::metadata::multiple_of, value); + return i; +} + +ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){ + ir::instruction* i = dynamic_cast(x); + if(!i) + throw_unreachable("max_contiguous"); + i->set_metadata(ir::metadata::max_contiguous, value); + return i; +} + +ir::value *dispatch::debug_barrier(ir::builder *builder) { + return builder->create_barrier(); +} + + +} +} diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 39bd945bc..c225b315f 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -312,8 +312,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast : - (arg_bits > dst_bits ? cast_op_t::Trunc : - (is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); + (arg_bits > dst_bits ? cast_op_t::Trunc : + (is_signed ? 
cast_op_t::SExt : cast_op_t::ZExt))); return create(op, arg, ty, name, next); } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index a37d3048f..33b39de3a 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -9,6 +9,146 @@ namespace triton{ namespace ir{ +/* Module */ +module::module(const std::string &name, builder &builder) + : name_(name), builder_(builder) { + sealed_blocks_.insert(nullptr); +} + +ir::builder& module::get_builder() { + return builder_; +} + +void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ + values_[val_key_t{name, block}] = value; + auto it = metadatas_.find(name); + if(auto *x = dynamic_cast(value)) + if(it != metadatas_.end()){ + x->set_metadata(it->second.first, it->second.second); + } +// value->set_name(name); +} + +void module::set_value(const std::string& name, ir::value *value){ + return set_value(name, builder_.get_insert_block(), value); +} + +void module::set_const(const std::string& name){ + const_.insert(name); +} + +void module::set_continue_fn(std::function fn) { + continue_fn_ = fn; +} + +std::function module::get_continue_fn() { + return continue_fn_; +} + +ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ + basic_block::iterator insert = block->get_first_non_phi(); + if(insert != block->end()){ + builder_.set_insert_point(insert); + } + ir::phi_node *res = builder_.create_phi(ty, num_values); + if(insert != block->end()) + builder_.set_insert_point(block); + return res; +} + +ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ + // find non-self references + std::set non_self_ref; + std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), + [phi](ir::value* op){ return op != phi && op; }); + // non-trivial + if(non_self_ref.size() != 1) + return phi; + // unique value or self-reference + ir::value *same = *non_self_ref.begin(); + assert(same != nullptr); + phi->replace_all_uses_with(same); + phi->erase_from_parent(); + std::set users = phi->get_users(); + for(ir::user* u: users) + if(auto *uphi = dynamic_cast(u)) + if(uphi != phi) + try_remove_trivial_phis(uphi); + return same; +} + + +ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){ + // already initialized + if(phi->get_num_operands()) + return phi; + ir::basic_block *block = phi->get_parent(); + for(ir::basic_block *pred: block->get_predecessors()){ + ir::value *value = get_value(name, pred); + phi->add_incoming(value, pred); + } + return phi; +} + +ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) { + ir::value *result; + bool is_const = const_.find(name) != const_.end(); + auto &preds = block->get_predecessors(); + ir::type *ty = types_.at(name); + if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ + incomplete_phis_[block][name] = make_phi(ty, 1, block); + result = (ir::value*)incomplete_phis_[block][name]; + } + else if(preds.size() <= 1){ + bool has_pred = preds.size(); + result = get_value(name, has_pred?preds.front():nullptr); + } + else{ + ir::phi_node* phi = make_phi(ty, 1, block); + set_value(name, block, phi); + result = add_phi_operands(name, phi); + if(auto *phi = dynamic_cast(result)) + result = try_remove_trivial_phis(phi); + } + if(auto *phi = dynamic_cast(result)){ + result = try_remove_trivial_phis(phi); + } + set_value(name, block, result); + return result; +} + +ir::value *module::get_value(const std::string& name, ir::basic_block *block) { 
+ ir::basic_block* save_block = builder_.get_insert_block(); + ir::basic_block::iterator save_pt = builder_.get_insert_point(); + val_key_t key(name, block); + if(values_.find(key) != values_.end()){ + return values_.at(key); + } + ir::value *result = get_value_recursive(name, block); + builder_.set_insert_point(save_block); + if(save_pt != save_block->end()) + builder_.set_insert_point(save_pt); + return result; +} + +ir::value *module::get_value(const std::string& name) { + return get_value(name, builder_.get_insert_block()); +} + +const std::string& module::get_name() { + return name_; +} + +void module::seal_block(ir::basic_block *block){ + for(auto &x: incomplete_phis_[block]){ + add_phi_operands(x.first, x.second); + if(get_value(x.first) == x.second) + set_value(x.first, try_remove_trivial_phis(x.second)); + } + sealed_blocks_.insert(block); + incomplete_phis_[block].clear(); +} + /* functions */ function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *&fn = (function*&)symbols_[name]; diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 056ae99e6..7e4e4e5d7 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const { unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } +signedness type::get_integer_signedness() const +{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); } + +bool type::is_integer_signed() const { + if (id_ != IntegerTyID) { + throw std::logic_error("type is " + repr() + ", not integer"); + } + return ((integer_type*)(this))->get_signedness() == signedness::SIGNED; +} + unsigned type::get_tile_bitwidth() const { return ((block_type*)(this))->get_bitwidth(); } @@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } +integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; } +integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; } +integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; } +integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index b66761ec3..9e53cc341 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -3,6 +3,7 @@ #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" +#include "triton/ir/dispatch.h" #include "triton/ir/enums.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -11,12 +12,10 @@ #include #include #include -#include #include #include "Python.h" #include #include -#include #include #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" @@ -542,6 +541,84 @@ void init_triton_codegen(py::module &&m) { }, py::return_value_policy::take_ownership); } +/*****************************************************************************/ +/* User-facing language features */ +/*****************************************************************************/ + +void init_triton_frontend(py::module &&m) { + using ret = py::return_value_policy; + + // programming model + m.def("program_id", &ir::dispatch::program_id, ret::reference); + 
m.def("num_programs", &ir::dispatch::num_programs, ret::reference); + // binary + m.def("add", &ir::dispatch::add, ret::reference); + m.def("sub", &ir::dispatch::sub, ret::reference); + m.def("mul", &ir::dispatch::mul, ret::reference); + m.def("truediv", &ir::dispatch::truediv, ret::reference); + m.def("floordiv", &ir::dispatch::floordiv, ret::reference); + m.def("fdiv", &ir::dispatch::fdiv, ret::reference); + m.def("mod", &ir::dispatch::mod, ret::reference); + m.def("and_", &ir::dispatch::and_, ret::reference); + m.def("or_", &ir::dispatch::or_, ret::reference); + m.def("xor_", &ir::dispatch::xor_, ret::reference); + m.def("lshr", &ir::dispatch::lshr, ret::reference); + m.def("shl", &ir::dispatch::shl, ret::reference); + // unary + m.def("plus", &ir::dispatch::plus, ret::reference); + m.def("minus", &ir::dispatch::minus, ret::reference); + m.def("invert", &ir::dispatch::invert, ret::reference); + // comparison + m.def("greater_than", &ir::dispatch::greater_than, ret::reference); + m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference); + m.def("less_than", &ir::dispatch::less_than, ret::reference); + m.def("less_equal", &ir::dispatch::less_equal, ret::reference); + m.def("equal", &ir::dispatch::equal, ret::reference); + m.def("not_equal", &ir::dispatch::not_equal, ret::reference); + // block creation + m.def("arange", &ir::dispatch::arange, ret::reference); + m.def("zeros", &ir::dispatch::zeros, ret::reference); + // type manipuatation + m.def("cat", &ir::dispatch::cat, ret::reference); + m.def("reshape", &ir::dispatch::reshape, ret::reference); + typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); + typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); + m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference); + m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference); + m.def("bitcast", &ir::dispatch::bitcast, ret::reference); + m.def("cast", &ir::dispatch::cast, ret::reference); + // memory + m.def("load", &ir::dispatch::load, ret::reference); + m.def("store", &ir::dispatch::store, ret::reference); + m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); + m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); + m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference); + m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference); + m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference); + m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference); + m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference); + m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference); + // linear algebra + m.def("dot", &ir::dispatch::dot, ret::reference); + // indexing + m.def("where", &ir::dispatch::where, ret::reference); + // reduction + m.def("min", &ir::dispatch::min, ret::reference); + m.def("max", &ir::dispatch::max, ret::reference); + m.def("sum", &ir::dispatch::sum, ret::reference); + m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); + // math + m.def("umulhi", &ir::dispatch::umulhi, ret::reference); + m.def("exp", &ir::dispatch::exp, ret::reference); + m.def("log", &ir::dispatch::log, ret::reference); + m.def("cos", &ir::dispatch::cos, ret::reference); + m.def("sin", &ir::dispatch::sin, ret::reference); + m.def("sqrt", &ir::dispatch::sqrt, ret::reference); + // internal (debugging only) + m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); + m.def("max_contiguous", &ir::dispatch::max_contiguous, 
ret::reference); + m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); +} /*****************************************************************************/ /* Python bindings for triton::ir */ @@ -551,86 +628,16 @@ void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; - py::enum_(m, "CACHE_MODIFIER") - .value("NONE", ir::load_inst::NONE) - .value("CA", ir::load_inst::CA) - .value("CG", ir::load_inst::CG) - .export_values(); - - py::enum_(m, "EVICTION_POLICY") - .value("NORMAL", ir::load_inst::NORMAL) - .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST) - .value("EVICT_LAST", ir::load_inst::EVICT_LAST) - .export_values(); - - py::enum_(m, "REDUCE_OP") - .value("ADD", ir::reduce_inst::ADD) - .value("FADD", ir::reduce_inst::FADD) - .value("MIN", ir::reduce_inst::MIN) - .value("MAX", ir::reduce_inst::MAX) - .value("FMIN", ir::reduce_inst::FMIN) - .value("FMAX", ir::reduce_inst::FMAX) - .value("XOR", ir::reduce_inst::XOR); - - py::enum_(m, "ATOMIC_OP") - .value("ADD", ir::atomic_rmw_op_t::Add) - .value("FADD", ir::atomic_rmw_op_t::FAdd) - .value("AND", ir::atomic_rmw_op_t::And) - .value("OR", ir::atomic_rmw_op_t::Or) - .value("XOR", ir::atomic_rmw_op_t::Xor) - .value("XCHG", ir::atomic_rmw_op_t::Xchg) - .value("MAX", ir::atomic_rmw_op_t::Max) - .value("MIN", ir::atomic_rmw_op_t::Min) - .value("UMIN", ir::atomic_rmw_op_t::UMin) - .value("UMAX", ir::atomic_rmw_op_t::UMax); - py::class_(m, "context") .def(py::init<>()); - py::class_(m, "value") - .def("multiple_of", [](ir::value *self, int val) { - if (auto *instr = dynamic_cast(self)) { - instr->set_metadata(ir::metadata::multiple_of, val); - } else - throw std::runtime_error("multiple_of"); - }) - .def("max_contiguous", [](ir::value *self, int val) { - if (auto *instr = dynamic_cast(self)) { - instr->set_metadata(ir::metadata::max_contiguous, val); - } else - throw std::runtime_error("max_contiguous"); - }) - .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) { - if (auto *instr = dynamic_cast(self)) - instr->set_fdiv_ieee_rounding(val); - else - throw std::runtime_error("set_fdiv_ieee_rounding"); - }) - .def("is_phi", [](ir::value *self) { - if (auto *pn = dynamic_cast(self)) - return true; - return false; - }) - .def("ops", [](ir::value *self) { - if (auto *instr = dynamic_cast(self)) { - return instr->ops(); - } - throw std::runtime_error("cannot use ops()"); - }) - .def("replace_all_uses_with", &ir::value::replace_all_uses_with) - .def("erase_from_parent", [](ir::value *self) { - if (auto *instr = dynamic_cast(self)) - return instr->erase_from_parent(); - throw std::runtime_error("cannot use erase_from_parent"); - }) - .def_property("name", &ir::value::get_name, &ir::value::set_name) - .def_property_readonly("type", &ir::value::get_type); + auto value = py::class_(m, "value"); + value.def_property("name", &ir::value::get_name, &ir::value::set_name); + value.def_property_readonly("type", &ir::value::get_type); py::class_(m, "user"); - py::class_(m, "constant") - .def("get_null_value", &ir::constant::get_null_value, ret::reference) - .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference); + py::class_(m, "constant"); py::class_(m, "undef") .def("get", &ir::undef_value::get, ret::reference); @@ -641,17 +648,16 @@ void init_triton_ir(py::module &&m) { .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); py::class_(m, "constant_float") - .def_property_readonly("value", &ir::constant_fp::get_value) - .def("get", 
[](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference); + .def_property_readonly("value", &ir::constant_fp::get_value); - py::class_(m, "instruction") - .def("get_parent", [](ir::instruction *self) { - return self->get_parent(); - }, ret::reference); - py::class_(m, "phi_node") - .def("add_incoming", &ir::phi_node::add_incoming); + py::class_(m, "instruction"); + py::class_(m, "phi_node"); py::class_(m, "type") + .def("is_ptr", &ir::type::is_pointer_ty) + .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("is_floating", &ir::type::is_floating_point_ty) + .def("is_block", &ir::type::is_block_ty) .def("make_ptr", &ir::pointer_type::get, ret::reference) .def("make_function", &ir::function_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference) @@ -666,38 +672,34 @@ void init_triton_ir(py::module &&m) { .def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference) - .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference) + .def("get_uint8", &ir::type::get_uint8_ty, ret::reference) + .def("get_uint16", &ir::type::get_uint16_ty, ret::reference) + .def("get_uint32", &ir::type::get_uint32_ty, ret::reference) + .def("get_uint64", &ir::type::get_uint64_ty, ret::reference) - .def("get_block_shapes", &ir::type::get_block_shapes) - - .def("is_ptr", &ir::type::is_pointer_ty) - .def("is_int", static_cast(&ir::type::is_integer_ty)) - .def("is_floating", &ir::type::is_floating_point_ty) - .def("is_block", &ir::type::is_block_ty) .def("is_void", &ir::type::is_void_ty) - .def("is_bool", &ir::type::is_bool_ty) .def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp16", &ir::type::is_fp16_ty) .def("is_bf16", &ir::type::is_bf16_ty) .def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) - .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) + .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) + .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) + .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) + .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) - .def_property_readonly("context", &ir::type::get_context, ret::reference) - .def_property_readonly("int_bitwidth", 
&ir::type::get_integer_bitwidth) - .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits); + .def_property_readonly("context", &ir::type::get_context, ret::reference); py::class_(m, "pointer_type") - .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) - .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference); + .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); py::class_(m, "function_type"); py::class_(m, "integer_type"); @@ -707,15 +709,16 @@ void init_triton_ir(py::module &&m) { py::class_(m, "module") .def(py::init()) - .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { - const auto metadatas = self->get_metadatas(); - auto it = metadatas.find(name); - if (it != metadatas.end()) - if (auto *instr = dynamic_cast(value)) { - instr->set_metadata(it->second.first, it->second.second); - } - }) - .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); + .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) + .def("seal_block", &ir::module::seal_block) + .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) + .def("set_type", &ir::module::set_type) + .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) + .def("get_values", &ir::module::get_values, ret::reference) + .def("set_values", &ir::module::set_values) + .def("get_types", &ir::module::get_types, ret::reference) + .def("set_types", &ir::module::set_types) + .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; py::enum_(m, "attribute_kind") @@ -739,13 +742,6 @@ void init_triton_ir(py::module &&m) { py::class_(m, "basic_block") .def("create", &ir::basic_block::create, ret::reference) - .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) - .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* { - ir::basic_block::iterator it = self->get_first_non_phi(); - if (it == self->end()) - return nullptr; - return *it; - }, ret::reference) .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); py::class_(m, "builder", py::dynamic_attr()) @@ -756,162 +752,17 @@ void init_triton_ir(py::module &&m) { .def("br", &ir::builder::create_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) - // insertion block/point, insert points are represented as (*bb, *instr) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) - .def("get_insert_point", [](ir::builder *self) { - ir::basic_block *bb = self->get_insert_block(); - ir::basic_block::iterator it = self->get_insert_point(); - ir::instruction *instr = it == bb->end() ? 
nullptr : *it; - return std::make_pair(bb, instr); - }, ret::reference) - .def("set_insert_point", [](ir::builder *self, std::pair pt) { - ir::basic_block *bb = pt.first; - ir::instruction *instr = pt.second; - if (instr) { - if (bb != instr->get_parent()) - throw std::runtime_error("invalid insertion point, instr not in bb"); - self->set_insert_point(instr); - } else { - assert(bb); - self->set_insert_point(bb); - } - }) - // Constants + // constants .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) - .def("get_uint32", &ir::builder::get_int32, ret::reference) - .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) - .def("get_uint64", &ir::builder::get_int64, ret::reference) + .def("get_int32", &ir::builder::get_int32, ret::reference) + .def("get_int64", &ir::builder::get_int64, ret::reference) + .def("get_uint32", &ir::builder::get_uint32, ret::reference) + .def("get_uint64", &ir::builder::get_uint64, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference) - .def("get_range", &ir::builder::get_range, ret::reference) - // Types - .def("get_void_ty", &ir::builder::get_void_ty, ret::reference) - .def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference) - .def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference) - .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference) - .def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference) - .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference) - .def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference) - .def("get_half_ty", &ir::builder::get_half_ty, ret::reference) - .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference) - .def("get_float_ty", &ir::builder::get_float_ty, ret::reference) - .def("get_double_ty", &ir::builder::get_double_ty, ret::reference) - // terminator instructions - .def("create_br", &ir::builder::create_br, ret::reference) - .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) - .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) - // Cast instructions - .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) - .def("create_cast", &ir::builder::create_cast, ret::reference) - .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) - .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) - .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) - .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) - .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) - .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference) - .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) - .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) - .def("create_downcast", &ir::builder::create_downcast, ret::reference) - // phi - .def("create_phi", &ir::builder::create_phi, ret::reference) - // Binary instructions - .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) - .def("create_fmul", &ir::builder::create_fmul, ret::reference) - .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) - .def("create_frem", &ir::builder::create_frem, ret::reference) - .def("create_fadd", &ir::builder::create_fadd, ret::reference) - 
.def("create_fsub", &ir::builder::create_fsub, ret::reference) - .def("create_mul", &ir::builder::create_mul, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) - .def("create_udiv", &ir::builder::create_udiv, ret::reference) - .def("create_srem", &ir::builder::create_srem, ret::reference) - .def("create_urem", &ir::builder::create_urem, ret::reference) - .def("create_add", &ir::builder::create_add, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_sub", &ir::builder::create_sub, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_shl", &ir::builder::create_shl, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_lshr", &ir::builder::create_lshr, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_ashr", &ir::builder::create_ashr, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - // GEP - .def("create_gep", &ir::builder::create_gep, ret::reference) - // Comparison (int) - .def("create_icmp", &ir::builder::create_icmp, ret::reference) - .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference) - .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference) - .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference) - .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference) - .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference) - .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference) - .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference) - .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference) - .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference) - .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference) - // Comparison (float) - .def("create_fcmp", &ir::builder::create_fcmp, ret::reference) - .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference) - .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) - .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) - .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) - .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) - .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) - .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) - .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) - .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) - .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) - .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) - .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) - // Logical - .def("create_and", &ir::builder::create_and, ret::reference) - .def("create_xor", &ir::builder::create_xor, ret::reference) - .def("create_or", &ir::builder::create_or, ret::reference) - // Input/Output - .def("create_load", &ir::builder::create_load, ret::reference) - .def("create_store", &ir::builder::create_store, ret::reference) - .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) - .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) - // 
Block instruction - .def("create_splat", &ir::builder::create_splat, ret::reference) - .def("create_reshape", &ir::builder::create_reshape, ret::reference) - .def("create_cat", &ir::builder::create_cat, ret::reference) - .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) - // atomic - .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) - .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) - - // Built-in instruction - .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) - .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) - .def("create_exp", &ir::builder::create_exp, ret::reference) - .def("create_cos", &ir::builder::create_cos, ret::reference) - .def("create_sin", &ir::builder::create_sin, ret::reference) - .def("create_log", &ir::builder::create_log, ret::reference) - .def("create_dot", &ir::builder::create_dot, ret::reference) - .def("create_trans", &ir::builder::create_trans, ret::reference) - .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) - .def("create_reduce", &ir::builder::create_reduce, ret::reference) - .def("create_select", &ir::builder::create_select, ret::reference) - // Intrinsics - // These have no place in the IR, and hopefully they can be removed at some point - .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) - .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference) - .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference) - .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference) - .def("create_barrier", &ir::builder::create_barrier, ret::reference) - .def("create_async_wait", &ir::builder::create_async_wait, ret::reference) - .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference); + .def("get_range", &ir::builder::get_range, ret::reference); } void init_triton(py::module &m) { @@ -919,4 +770,5 @@ void init_triton(py::module &m) { init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_ir(std::move(subm.def_submodule("ir"))); + init_triton_frontend(std::move(subm.def_submodule("frontend"))); } diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index f30b203bb..1df3a0b49 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -37,7 +37,7 @@ matmul_data = { (256, 256, 256): {'float16': 0.027}, (512, 512, 512): {'float16': 0.158}, (1024, 1024, 1024): {'float16': 0.466}, - (2048, 2048, 2048): {'float16': 0.695}, + (2048, 2048, 2048): {'float16': 0.680}, (4096, 4096, 4096): {'float16': 0.831}, (8192, 8192, 8192): {'float16': 0.849}, # tall-skinny diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3561f7af4..a49b47585 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,4 +1,5 @@ # flake8: noqa: F821,F841 +import copy import itertools import re from typing import Optional, Union @@ -584,6 +585,7 @@ def test_f8_f16_roundtrip(): f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) + print(f16.dtype, f8_output.dtype) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -991,6 +993,27 @@ def 
test_noop(device='cuda'): kernel[(1, )](x) +@pytest.mark.parametrize("value, value_type", [ + (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') +]) +def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device='cuda') + pgm = kernel[(1, )](value, x) + + # Parse out the type of the 'VALUE' parameter from the Triton IR. + triton_ir = pgm.asm['ttir'] + ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) + ir_value_type = None if ir_value_match is None else ir_value_match.group(1) + assert ir_value_type == value_type + + @pytest.mark.parametrize( "value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index d866d6983..8ac01bcc8 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,5 +1,4 @@ import os -import re import shutil import pytest @@ -103,30 +102,3 @@ def test_specialize(mode): for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target - - -@pytest.mark.parametrize("value, value_type", [ - (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), - (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), - (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64') -]) -def test_value_specialization(value: int, value_type: str, device='cuda') -> None: - - @triton.jit - def kernel(VALUE, X): - pass - - cache_str = None - - def get_cache_str(*args, **kwargs): - nonlocal cache_str - cache_str = kwargs['key'].split('-') - triton.code_gen.JITFunction.cache_hook = get_cache_str - reset_tmp_dir() - x = torch.tensor([3.14159], device='cuda') - kernel[(1, )](value, x) - triton.code_gen.JITFunction.cache_hook = None - - cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) - spec_type = None if cache_str_match is None else cache_str_match.group(1) - assert spec_type == value_type diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 37ba46efc..f9982939c 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,8 +6,7 @@ __version__ = '2.0.0' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ - JITFunction, Config, Autotuner, reinterpret +from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret from . import language from . import code_gen from . 
import testing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 23d460f29..cb705aaa6 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import ast import builtins import functools @@ -13,7 +11,7 @@ import tempfile import textwrap import time import warnings -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict import torch from filelock import FileLock @@ -24,13 +22,48 @@ from .tools.disasm import extract class CodeGenerator(ast.NodeVisitor): + def get_value(self, name): + # search node.id in local scope + ret = None + if name in self.lscope: + ret = self.lscope[name] + # search node.id in global scope + elif name in self.gscope: + ret = self.gscope[name] + # search node.id in builtins + elif name in self.builtins: + ret = self.builtins[name] + else: + raise ValueError(f'{name} is not defined') + if isinstance(ret, triton.language.block): + handle = self.module.get_value(name) + return triton.language.block(handle) + return ret + + def set_value(self, name, value): + if isinstance(value, _triton.ir.value): + value = triton.language.block(value) + if isinstance(value, triton.language.block): + self.module.set_value(name, value.handle) + self.module.set_type(name, value.handle.type) + self.lscope[name] = value + + def is_triton_object(self, value): + return isinstance(value, triton.language.block) + + def visit_compound_statement(self, stmts): + for stmt in stmts: + self.last_ret = self.visit(stmt) + if isinstance(stmt, ast.Return): + break + return stmts and isinstance(stmt, ast.Return) + def __init__(self, context, prototype, gscope, attributes, constants, kwargs): self.builder = _triton.ir.builder(context) self.module = _triton.ir.module('', self.builder) self.prototype = prototype self.gscope = gscope self.lscope = dict() - self.is_arg_lscope = dict() # name => is_arg: {str: bool} self.attributes = attributes self.constants = constants self.kwargs = kwargs @@ -44,146 +77,6 @@ class CodeGenerator(ast.NodeVisitor): 'isinstance': isinstance, 'getattr': getattr, } - # SSA-construction - # [name, bb] => triton.language.tensor - self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {} - # bb => {name => phi} - self.incomplete_phis = {} - self.sealed_blocks: Set[_triton.ir.basic_block] = set() - - def get_value(self, name): - ''' This function: - 1. make sure `name` is defined - 2. if `name` is triton.language.tensor, get stored tensor by calling - `self._get_tensor()` - ''' - # search node.id in local scope - ret = None - if name in self.lscope: - ret = self.lscope[name] - # search node.id in global scope - elif name in self.gscope: - ret = self.gscope[name] - # search node.id in builtins - elif name in self.builtins: - ret = self.builtins[name] - else: - raise ValueError(f'{name} is not defined') - if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]: - return self._get_tensor(name) - return ret - - def set_value(self, name: str, - value: Union[triton.language.tensor, triton.language.constexpr], - is_arg: bool = False) -> None: - ''' This function: - called by visit_Assign() & visit_FuncDef() to store left value (lvalue) - 1. record local defined name (FIXME: should consider control flow) - 2. 
store tensor in self.lvalue - ''' - self.lscope[name] = value - # if this value is an argument, we don't need to create phis for it - self.is_arg_lscope[name] = is_arg - if isinstance(value, triton.language.tensor) and not is_arg: - self._set_value(name, self.builder.get_insert_block(), value) - - # - # SSA-construction - # - def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor: - if not bb: - bb = self.builder.get_insert_block() - # local value numbering - if (name, bb) in self.lvalues: - return self.lvalues[(name, bb)] - # global value numbering - saved_insert_point = self.builder.get_insert_point() - result = self._get_tensor_recursive(name, bb) - self.builder.set_insert_point(saved_insert_point) - return result - - def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: - preds = bb.get_predecessors() - type = self.lscope[name].type - # some preds haven't been filled, create a phi as a proxy of the value - if bb not in self.sealed_blocks: - result = self._make_phi(type, len(preds), bb) - if bb in self.incomplete_phis: - self.incomplete_phis[bb][name] = result - else: - self.incomplete_phis[bb] = {name: result} - elif len(preds) == 1: - # one predecessor: no phi needed, try get value from pred - result = self._get_tensor(name, preds[0]) - else: # multiple preds - assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)' - phi = self._make_phi(type, len(preds), bb) - self._set_value(name, bb, phi) - result = self._add_phi_operands(name, phi) - self._set_value(name, bb, result) - return result - - # returns a new phi tensor, which encausulate an ir.phi_node - def _make_phi(self, - type: triton.language.dtype, - num_values: int, - bb: _triton.ir.basic_block) -> triton.language.tensor: - instr = bb.get_first_non_phi() - self.builder.set_insert_point((bb, instr)) - ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values) - if instr: - self.builder.set_insert_block(bb) - return triton.language.tensor(ir_phi, type) - - # complete a phi node. (TODO: rename this as _complete_phis?) - # Note: since we try to remove tryival phi, the return tensor might not be a phi - def _add_phi_operands(self, name: str, - phi: triton.language.tensor) -> triton.language.tensor: - bb = phi.handle.get_parent() - for pred in bb.get_predecessors(): - v = self._get_tensor(name, pred) - phi.handle.add_incoming(v.handle, pred) - phi = self._try_remove_trivial_phi(phi) - return phi - - def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None: - self.lvalues[(name, bb)] = value - # TODO: why we need this? 
- self.module.set_instr_metadata(name, value.handle) - - def _seal_block(self, bb: _triton.ir.basic_block): - # complete all incomplete phis - if bb in self.incomplete_phis: - for name, phi in self.incomplete_phis[bb].items(): - result = self._add_phi_operands(name, phi) - # it's possible that this phi is trivial - if self._get_tensor(name, bb).handle == phi.handle: - self._set_value(name, bb, result) - del self.incomplete_phis[bb] - self.sealed_blocks.add(bb) - - def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor: - unique_handles = {op for op in phi.handle.ops() if op != phi.handle} - if len(unique_handles) != 1: # non-trivial phi - return phi - v = unique_handles.pop() - phi.handle.replace_all_uses_with(v) - phi.handle.erase_from_parent() - # TODO: remove trivial phis recursively - return triton.language.tensor(v, phi.type) - - def is_triton_tensor(self, value): - return isinstance(value, triton.language.tensor) - - # - # AST visitor - # - def visit_compound_statement(self, stmts): - for stmt in stmts: - self.last_ret = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) def visit_Module(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -220,7 +113,7 @@ class CodeGenerator(ast.NodeVisitor): if inline: pass else: - fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder)) + fn = self.module.get_or_insert_function(node.name, self.prototype) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -237,17 +130,17 @@ class CodeGenerator(ast.NodeVisitor): attr = _triton.ir.attribute(attr, self.attributes[i]) fn.add_attr(idx + 1, attr) fn.args[idx].name = arg_name - arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) + arg_values.append(fn.args[idx]) idx += 1 for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value, is_arg=True) + self.set_value(arg_name, arg_value) if inline: self.visit_compound_statement(node.body) return self.last_ret else: entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) - self._seal_block(entry) + self.module.seal_block(entry) self.builder.set_insert_block(entry) # visit function body self.visit_compound_statement(node.body) @@ -294,12 +187,11 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): - # TODO: can we store constexpr here to support constant folding? 
# by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value - if not isinstance(value, triton.language.tensor): - value = triton.language.core._to_tensor(value, self.builder) + if not isinstance(value, triton.language.block): + value = triton.language.core._to_ir(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -328,9 +220,9 @@ class CodeGenerator(ast.NodeVisitor): def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) - if isinstance(lhs, triton.language.constexpr): + if isinstance(lhs, triton.language.core.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): + if isinstance(rhs, triton.language.core.constexpr): rhs = rhs.value fn = { ast.Add: '__add__', @@ -346,9 +238,9 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - if self.is_triton_tensor(lhs): + if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_tensor(rhs): + elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -356,15 +248,15 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if isinstance(cond, triton.language.tensor): + if isinstance(cond, triton.language.block): cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self._seal_block(then_bb) + self.module.seal_block(then_bb) if else_bb: - self._seal_block(else_bb) + self.module.seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -379,7 +271,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
if not is_terminator: self.builder.br(endif_bb) - self._seal_block(endif_bb) + self.module.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -404,9 +296,9 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.constexpr): + if isinstance(lhs, triton.language.core.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): + if isinstance(rhs, triton.language.core.constexpr): rhs = rhs.value if type(node.ops[0]) == ast.Is: return triton.language.constexpr(lhs is rhs) @@ -420,9 +312,9 @@ class CodeGenerator(ast.NodeVisitor): ast.Gt: '__gt__', ast.GtE: '__ge__', }[type(node.ops[0])] - if self.is_triton_tensor(lhs): + if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_tensor(rhs): + elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -433,21 +325,21 @@ class CodeGenerator(ast.NodeVisitor): if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.constexpr): + if isinstance(op, triton.language.core.constexpr): op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Invert: '__invert__', }[type(node.op)] - if self.is_triton_tensor(op): + if self.is_triton_object(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() def visit_While(self, node): current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) def continue_fn(): cond = self.visit(node.test) @@ -458,9 +350,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self._seal_block(stop_bb) - self._seal_block(loop_bb) - self._seal_block(next_bb) + self.module.seal_block(stop_bb) + self.module.seal_block(loop_bb) + self.module.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -470,7 +362,7 @@ class CodeGenerator(ast.NodeVisitor): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) - if self.is_triton_tensor(lhs): + if self.is_triton_object(lhs): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] @@ -513,8 +405,8 @@ class CodeGenerator(ast.NodeVisitor): step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) def continue_fn(): self.visit(step_node) @@ -529,9 +421,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where 
body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self._seal_block(stop_bb) - self._seal_block(loop_bb) - self._seal_block(next_bb) + self.module.seal_block(stop_bb) + self.module.seal_block(loop_bb) + self.module.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -559,7 +451,7 @@ class CodeGenerator(ast.NodeVisitor): args = [self.visit(arg) for arg in node.args] if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) - if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ + if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): @@ -699,7 +591,7 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] - if isinstance(obj, triton.language.constexpr): + if isinstance(obj, triton.language.core.constexpr): obj = obj.value if isinstance(obj, int): if -2**31 <= obj < 2**31: @@ -731,34 +623,34 @@ class Kernel: return 'scalar', name @staticmethod - def _to_triton_ir(obj): + def _to_triton_ir(context, obj): which, name = obj type_map = { - 'I': triton.language.int32, - 'L': triton.language.int64, - 'f': triton.language.float32, - 'B': triton.language.int1, - 'f8': triton.language.float8, - 'f16': triton.language.float16, - 'bf16': triton.language.bfloat16, - 'f32': triton.language.float32, - 'f64': triton.language.float64, - 'i1': triton.language.int1, - 'i8': triton.language.int8, - 'i16': triton.language.int16, - 'i32': triton.language.int32, - 'i64': triton.language.int64, - 'u8': triton.language.uint8, - 'u16': triton.language.uint16, - 'u32': triton.language.uint32, - 'u64': triton.language.uint64, + 'I': _triton.ir.type.get_int32, + 'L': _triton.ir.type.get_int64, + 'f': _triton.ir.type.get_fp32, + 'B': _triton.ir.type.get_int1, + 'f8': _triton.ir.type.get_fp8, + 'f16': _triton.ir.type.get_fp16, + 'bf16': _triton.ir.type.get_bf16, + 'f32': _triton.ir.type.get_fp32, + 'f64': _triton.ir.type.get_fp64, + 'i1': _triton.ir.type.get_int1, + 'i8': _triton.ir.type.get_int8, + 'i16': _triton.ir.type.get_int16, + 'i32': _triton.ir.type.get_int32, + 'i64': _triton.ir.type.get_int64, + 'u8': _triton.ir.type.get_uint8, + 'u16': _triton.ir.type.get_uint16, + 'u32': _triton.ir.type.get_uint32, + 'u64': _triton.ir.type.get_uint64, } # convert torch.Tensor to Triton IR pointers if which == 'ptr': - elt_ty = type_map[name] - return triton.language.pointer_type(elt_ty, 1) + elt_ty = type_map[name](context) + return _triton.ir.type.make_ptr(elt_ty, 1) # default path returns triton.ir.type directly - return type_map[name] + return type_map[name](context) @staticmethod def pow2_divisor(N): @@ -1038,31 +930,25 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - # Called by CodeGenerator.visit_Call() def __call__(self, *args, generator: CodeGenerator, **kwargs): try: from inspect import getcallargs arg_values = getcallargs(self.fn, *args, **kwargs) arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.tensor) + arg_values = [arg if isinstance(arg, triton.language.block) else triton.language.constexpr(arg) for arg in arg_values] - # Record values in the caller (parent scope) gscope = generator.gscope.copy() lscope = generator.lscope.copy() - - # TODO: clear values other than args - lvalues = generator.lvalues.copy() - # types = 
generator.module.get_types().copy() + values = generator.module.get_values().copy() + types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope - - generator.lvalues = lvalues - # generator.module.set_types(types) - + generator.module.set_values(values) + generator.module.set_types(types) return ret except Exception as e: node = generator.last_node @@ -1147,9 +1033,9 @@ class JITFunction: # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type(ret_type, arg_types) + arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] + ret_type = _triton.ir.type.get_void(context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 81b9fe790..df25e59fb 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,36 +1,63 @@ -from __future__ import annotations - -from enum import Enum from functools import wraps -from typing import List import triton -from . import semantic -from triton._C.libtriton.triton import ir +from triton._C.libtriton.triton import frontend, ir -def _to_tensor(x, builder): +# convert block/dtype to ir values +def _to_ir(x, builder): if isinstance(x, bool): - return tensor(builder.get_int1(x), int1) - # Note: compile-time const integers are represented by unsigned values + return builder.get_int1(x) elif isinstance(x, int): if -2**31 <= x < 2**31: - return tensor(builder.get_int32(x), int32) + return builder.get_int32(x) elif 2**31 <= x < 2**32: - return tensor(builder.get_uint32(x), uint32) + return builder.get_uint32(x) elif -2**63 <= x < 2**63: - return tensor(builder.get_int64(x), int64) + return builder.get_int64(x) elif 2**63 <= x < 2**64: - return tensor(builder.get_uint64(x), uint64) + return builder.get_uint64(x) else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - return tensor(builder.get_float32(x), float32) + return builder.get_float32(x) elif isinstance(x, constexpr): - return _to_tensor(x.value, builder) - elif isinstance(x, tensor): + return _to_ir(x.value, builder) + elif isinstance(x, block): + return x.handle + elif isinstance(x, dtype): + return x.handle(builder) + return x + + +def _patch(fn): + def _from_ir(x): + if isinstance(x, ir.value): + if x.type.is_void(): + return None + return block(x) return x - assert False, f'cannot convert {x} to tensor' + + def wrapper(*args, **kwargs): + builder = args[-1] + assert isinstance(builder, ir.builder) + args = [_to_ir(x, builder) for x in args] + # for i, arg in enumerate(args): + # if arg is None: + # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") + kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} + ret = fn(*args, **kwargs) + if isinstance(ret, tuple): + return map(_from_ir, ret) + return _from_ir(ret) + + return wrapper + + +for name in dir(frontend): + fn = getattr(frontend, name) + if callable(fn): + setattr(frontend, name, _patch(fn)) def builtin(fn): @@ -45,147 +72,20 @@ def builtin(fn): class dtype: - 
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] - UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] - FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64'] - OTHER_TYPES = ['void'] - - class SIGNEDNESS(Enum): - SIGNED = 0 - UNSIGNED = 1 - - def __init__(self, name): - self.name = name - assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name - if name in dtype.SINT_TYPES: - self.int_signedness = dtype.SIGNEDNESS.SIGNED - self.int_bitwidth = int(name.split('int')[-1]) - self.primitive_bitwidth = self.int_bitwidth - elif name in dtype.UINT_TYPES: - self.int_signedness = dtype.SIGNEDNESS.UNSIGNED - self.int_bitwidth = int(name.split('int')[-1]) - self.primitive_bitwidth = self.int_bitwidth - elif name in dtype.FP_TYPES: - if name == 'fp8': - self.fp_mantissa_width = 3 - self.primitive_bitwidth = 8 - elif name == 'fp16': - self.fp_mantissa_width = 10 - self.primitive_bitwidth = 16 - elif name == 'bf16': - self.fp_mantissa_width = 7 - self.primitive_bitwidth = 16 - elif name == 'fp32': - self.fp_mantissa_width = 23 - self.primitive_bitwidth = 32 - elif name == 'fp64': - self.fp_mantissa_width = 53 - self.primitive_bitwidth = 64 - elif name == 'void': - self.primitive_bitwidth = 0 - - def is_fp8(self): - return self.name == 'fp8' - - def is_fp16(self): - return self.name == 'fp16' - - def is_bf16(self): - return self.name == 'bf16' - - def is_fp32(self): - return self.name == 'fp32' - - def is_fp64(self): - return self.name == 'fp64' - - def is_int1(self): - return self.name == 'int1' - - def is_int8(self): - return self.name == 'int8' - - def is_int16(self): - return self.name == 'int16' - - def is_int32(self): - return self.name == 'int32' - - def is_int64(self): - return self.name == 'int64' - - def is_uint8(self): - return self.name == 'uint8' - - def is_uint16(self): - return self.name == 'uint16' - - def is_uint32(self): - return self.name == 'uint32' - - def is_uint64(self): - return self.name == 'uint64' - - def is_floating(self): - return self.name in dtype.FP_TYPES - - def is_int_signed(self): - return self.name in dtype.SINT_TYPES - - def is_int(self): - return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES - - def is_bool(self): - return self.is_int1() - - def is_void(self): - raise RuntimeError("Not implemented") - - def is_block(self): - return False - - def is_ptr(self): - return False - - def __eq__(self, other: dtype): - if not isinstance(other, dtype): - return False - return self.name == other.name - - def __ne__(self, other: dtype): - return not self.__eq__(other) - - def __hash__(self): - return hash((self.name,)) + def __init__(self, init): + self.init = init @property - def scalar(self): - return self + def name(self) -> str: + # The init functions are named something like 'get_int8'. Strip the prefix. 
+ nom = self.init.__name__ + prefix = 'get_' + assert nom.startswith(prefix) + return nom[len(prefix):] - def to_ir(self, builder: ir.builder) -> ir.type: - if self.name == 'void': - return builder.get_void_ty() - elif self.name == 'int1': - return builder.get_int1_ty() - elif self.name == 'int8' or self.name == 'uint8': - return builder.get_int8_ty() - elif self.name == 'int16' or self.name == 'uint16': - return builder.get_int16_ty() - elif self.name == 'int32' or self.name == 'uint32': - return builder.get_int32_ty() - elif self.name == 'int64' or self.name == 'uint64': - return builder.get_int64_ty() - elif self.name == 'fp8': - return builder.get_fp8_ty() - elif self.name == 'fp16': - return builder.get_half_ty() - elif self.name == 'bf16': - return builder.get_bf16_ty() - elif self.name == 'fp32': - return builder.get_float_ty() - elif self.name == 'fp64': - return builder.get_double_ty() - raise ValueError(f'fail to covert {self} to ir type') + def handle(self, builder): + ctx = builder.context + return self.init(ctx) def __str__(self): return self.name @@ -199,112 +99,36 @@ class dtype: return f'triton.language.{self.name}' -class pointer_type(dtype): - def __init__(self, element_ty: dtype, address_space: int = 1): +class pointer_dtype: + def __init__(self, element_ty): if not isinstance(element_ty, dtype): raise TypeError('element_ty is a {type(element_ty).__name__}.') self.element_ty = element_ty - self.address_space = address_space - self.name = self.__str__() - - def to_ir(self, builder: ir.builder) -> ir.pointer_type: - return ir.type.make_ptr(self.element_ty.to_ir(builder), 1) + def handle(self, builder): + return ir.type.make_ptr(self.element_ty.handle(builder), 1) def __str__(self): return f'pointer<{self.element_ty}>' - def __repr__(self): - return self.__str__() - - def is_ptr(self): - return True - - def __eq__(self, other: pointer_type) -> bool: - if not isinstance(other, pointer_type): - return False - return self.element_ty == other.element_ty and self.address_space == other.address_space - - def __ne__(self, other: pointer_type) -> bool: - return not self.__eq__(other) - - @property - def scalar(self): - return self - - -class block_type(dtype): - def __init__(self, element_ty: dtype, shape: List[int]): - self.element_ty = element_ty - # FIXME: - # block_type's shape is a list of int - # while tensor's shape is a list of constexpr - self.shape = shape - self.numel = 1 - for s in self.shape: - self.numel *= s - - self.name = self.__str__() - - def to_ir(self, builder: ir.builder) -> ir.block_type: - return ir.type.make_block(self.element_ty.to_ir(builder), self.shape) - - def __str__(self): - return f'<{self.shape}, {self.element_ty}>' - - def __repr__(self): - return self.__str__() - - def is_block(self): - return True - - def get_block_shapes(self) -> List[int]: - return self.shape - - def __eq__(self, other: block_type) -> bool: - if not isinstance(other, block_type): - return False - return self.element_ty == other.element_ty and self.shape == other.shape - - def __ne__(self, other: block_type) -> bool: - return not self.__eq__(other) - - @property - def scalar(self): - return self.element_ty - - -class function_type(dtype): - def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None: - self.ret_type = ret_type - self.param_types = param_types - - def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_type}' - - def to_ir(self, builder: ir.builder): - ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - return 
ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types) - # scalar types -void = dtype('void') -int1 = dtype('int1') -int8 = dtype('int8') -int16 = dtype('int16') -int32 = dtype('int32') -int64 = dtype('int64') -uint8 = dtype('uint8') -uint16 = dtype('uint16') -uint32 = dtype('uint32') -uint64 = dtype('uint64') -float8 = dtype('fp8') -float16 = dtype('fp16') -bfloat16 = dtype('bf16') -float32 = dtype('fp32') -float64 = dtype('fp64') +int1 = dtype(ir.type.get_int1) +int8 = dtype(ir.type.get_int8) +int16 = dtype(ir.type.get_int16) +int32 = dtype(ir.type.get_int32) +int64 = dtype(ir.type.get_int64) +uint8 = dtype(ir.type.get_uint8) +uint16 = dtype(ir.type.get_uint16) +uint32 = dtype(ir.type.get_uint32) +uint64 = dtype(ir.type.get_uint64) +float8 = dtype(ir.type.get_fp8) +float16 = dtype(ir.type.get_fp16) +bfloat16 = dtype(ir.type.get_bf16) +float32 = dtype(ir.type.get_fp32) +float64 = dtype(ir.type.get_fp64) # pointer types -pi32_t = pointer_type(int32) +pi32_t = pointer_dtype(int32) # ----------------------- # constexpr @@ -325,6 +149,7 @@ class constexpr: def __repr__(self) -> str: return f"constexpr[{self.value}]" + # def __add__(self, other): return self.value + other.value @@ -394,33 +219,31 @@ class constexpr: return self.value(*args, **kwds) -class tensor: - # infer dtype from ir type +class block: @staticmethod - def _to_dtype(ir_type): - # block type - if ir_type.is_block(): - scalar_ty = tensor._to_dtype(ir_type.scalar) - return block_type(scalar_ty, ir_type.get_block_shapes()) - # pointer type - if ir_type.is_ptr(): - element_ty = tensor._to_dtype(ir_type.element) - return pointer_type(element_ty) + def _init_dtype(ir_type): # primitive type - if ir_type.is_void(): return void if ir_type.is_int1(): return int1 if ir_type.is_int8(): return int8 if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 + if ir_type.is_uint8(): return uint8 + if ir_type.is_uint16(): return uint16 + if ir_type.is_uint32(): return uint32 + if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16 if ir_type.is_fp32(): return float32 if ir_type.is_fp64(): return float64 - raise ValueError(f"Unsupported type {ir_type.repr()}") + # pointer type + if ir_type.is_ptr(): + element_ty = block._init_dtype(ir_type.element) + return pointer_dtype(element_ty) + raise ValueError(f"Unsupported type {ir_type}") - def __init__(self, handle, type: dtype): + def __init__(self, handle): # IR handle self.handle = handle # Block shape @@ -431,9 +254,9 @@ class tensor: for s in self.shape: self.numel *= s self.numel = constexpr(self.numel) - self.type = type # Tensor type (can be block_type) - # Following the practice in pytorch, dtype is scalar type - self.dtype = type.scalar + # Data-type wrapper + self.dtype = block._init_dtype(self.handle.type.scalar) + # Shape is a constexpr self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: @@ -442,139 +265,116 @@ class tensor: @builtin def __add__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.add(self, other, _builder) + return frontend.add(self, other, _builder) def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.sub(self, other, _builder) + return frontend.sub(self, other, _builder) def __rsub__(self, other, _builder=None): - other = 
_to_tensor(other, _builder) - return semantic.sub(other, self, _builder) + return frontend.sub(other, self, _builder) @builtin def __mul__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.mul(self, other, _builder) + return frontend.mul(self, other, _builder) def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @builtin def __truediv__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.truediv(self, other, _builder) + return frontend.truediv(self, other, _builder) def __rtruediv__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.truediv(other, self, _builder) + return frontend.truediv(other, self, _builder) @builtin def __floordiv__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.floordiv(self, other, _builder) + return frontend.floordiv(self, other, _builder) @builtin def __mod__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.mod(self, other, _builder) + return frontend.mod(self, other, _builder) # unary operators @builtin def __neg__(self, _builder=None): - return semantic.minus(self, _builder) + return frontend.minus(self, _builder) @builtin def __invert__(self, _builder=None): - return semantic.invert(self, _builder) + return frontend.invert(self, _builder) # bitwise operators @builtin def __and__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.and_(self, other, _builder) + return frontend.and_(self, other, _builder) @builtin def __or__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.or_(self, other, _builder) + return frontend.or_(self, other, _builder) @builtin def __xor__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.xor_(self, other, _builder) + return frontend.xor_(self, other, _builder) @builtin def __lshift__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.shl(self, other, _builder) + return frontend.shl(self, other, _builder) @builtin def __rshift__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.lshr(self, other, _builder) + return frontend.lshr(self, other, _builder) # comparison operators # > @builtin def __gt__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_than(self, other, _builder) + return frontend.greater_than(self, other, _builder) @builtin def __rgt__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_than(other, self, _builder) + return frontend.greater_than(other, self, _builder) # >= @builtin def __ge__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_equal(self, other, _builder) + return frontend.greater_equal(self, other, _builder) - @builtin def __rge__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_equal(other, self, _builder) + return frontend.greater_equal(other, self, _builder) # < @builtin def __lt__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.less_than(self, other, _builder) + return frontend.less_than(self, other, _builder) @builtin def __rlt__(self, other, _builder=None): - return semantic.less_than(other, self, _builder) + return frontend.less_than(other, self, _builder) # <= @builtin def __le__(self, other, _builder=None): - other = 
_to_tensor(other, _builder) - return semantic.less_equal(self, other, _builder) + return frontend.less_equal(self, other, _builder) @builtin def __rle__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.less_equal(other, self, _builder) + return frontend.less_equal(other, self, _builder) # == @builtin def __eq__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.equal(self, other, _builder) + return frontend.equal(self, other, _builder) @builtin def __ne__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.not_equal(self, other, _builder) + return frontend.not_equal(self, other, _builder) @builtin def __getitem__(self, slices, _builder=None): @@ -589,25 +389,20 @@ class tensor: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) curr += 1 - ret = semantic.reshape(self, dst_shape, _builder) + ret = frontend.reshape(self, dst_shape, _builder) return ret @builtin def to(self, dtype, bitcast=False, _builder=None): - if isinstance(bitcast, constexpr): - bitcast = bitcast.value + dtype = dtype.handle(_builder) if bitcast: - return semantic.bitcast(self, dtype, _builder) - return semantic.cast(self, dtype, _builder) + return frontend.bitcast(self, dtype, _builder) + return frontend.cast(self, dtype, _builder) # ----------------------- # SPMD Programming Model # ----------------------- -def _constexpr_to_value(v): - if isinstance(v, constexpr): - return v.value - return v @builtin @@ -619,14 +414,13 @@ def program_id(axis, _builder=None): :type axis: int """ # if axis == -1: - # pid0 = program_id(0, _builder) - # pid1 = program_id(1, _builder) - # pid2 = program_id(2, _builder) - # npg0 = num_programs(0, _builder) - # npg1 = num_programs(0, _builder) + # pid0 = frontend.program_id(0, _builder) + # pid1 = frontend.program_id(1, _builder) + # pid2 = frontend.program_id(2, _builder) + # npg0 = frontend.num_programs(0, _builder) + # npg1 = frontend.num_programs(0, _builder) # return pid0 + pid1*npg0 + pid2*npg0*npg1 - axis = _constexpr_to_value(axis) - return semantic.program_id(axis, _builder) + return frontend.program_id(axis, _builder) @builtin @@ -637,8 +431,7 @@ def num_programs(axis, _builder=None): :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - axis = _constexpr_to_value(axis) - return semantic.num_programs(axis, _builder) + return frontend.num_programs(axis, _builder) # ----------------------- @@ -656,15 +449,13 @@ def arange(start, end, _builder=None): :param stop: End of the interval. Must be a power of two >= start. :type stop: int """ - start = _constexpr_to_value(start) - end = _constexpr_to_value(end) - return semantic.arange(start, end, _builder) + return frontend.arange(start, end, _builder) @builtin def zeros(shape, dtype, _builder=None): """ - Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. 
:param shape: Shape of the new array, e.g., (8, 16) or (8, ) :type shape: tuple of ints @@ -677,8 +468,7 @@ def zeros(shape, dtype, _builder=None): if not isinstance(d.value, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") shape = [x.value for x in shape] - dtype = _constexpr_to_value(dtype) - return semantic.zeros(shape, dtype, _builder) + return frontend.zeros(shape, dtype, _builder) # ----------------------- @@ -691,25 +481,25 @@ def broadcast(input, other, _builder=None): """ Tries to broadcast the two given blocks to a common compatible shape. - :param input: The first input tensor. + :param input: The first input block. :type input: Block - :param other: The second input tensor. + :param other: The second input block. :type other: Block """ - return semantic.broadcast_impl_value(input, other, _builder) + return frontend.broadcast(input, other, _builder) @builtin def broadcast_to(input, shape, _builder=None): """ - Tries to broadcast the given tensor to a new :code:`shape`. + Tries to broadcast the given block to a new :code:`shape`. - :param input: The input tensor. + :param input: The input block. :type input: Block :param shape: The desired shape. :type shape: Tuple[int] """ - return semantic.broadcast_impl_shape(input, shape, _builder) + return frontend.broadcast_to(input, shape, _builder) @builtin @@ -717,27 +507,27 @@ def cat(input, other, _builder=None): """ Concatenate the given blocks - :param input: The first input tensor. + :param input: The first input block. :type input: - :param other: The second input tensor. + :param other: The second input block. :type other: """ - return semantic.cat(input, other, _builder) + return frontend.cat(input, other, _builder) @builtin def reshape(input, shape, _builder=None): """ - Tries to reshape the given tensor to a new shape. + Tries to reshape the given block to a new shape. - :param input: The input tensor. + :param input: The input block. :type input: :param shape: The desired shape. :type shape: Tuple[int] """ shape = [x.value for x in shape] - return semantic.reshape(input, shape, _builder) + return frontend.reshape(input, shape, _builder) # ----------------------- @@ -752,13 +542,12 @@ def dot(input, other, allow_tf32=True, _builder=None): The two blocks must be two dimensionals and have compatible inner dimensions. - :param input: The first tensor to be multiplied. - :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} - :param other: The second tensor to be multiplied. - :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input: The first block to be multiplied. + :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second block to be multiplied. + :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ - allow_tf32 = _constexpr_to_value(allow_tf32) - return semantic.dot(input, other, allow_tf32, _builder) + return frontend.dot(input, other, allow_tf32, _builder) # ----------------------- @@ -769,7 +558,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ - Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. 
+ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. @@ -784,36 +573,24 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - # mask, other can be constexpr - if mask is not None: - mask = _to_tensor(mask, _builder) - if other is not None: - other = _to_tensor(other, _builder) - cache_modifier = _constexpr_to_value(cache_modifier) - eviction_policy = _constexpr_to_value(eviction_policy) - volatile = _constexpr_to_value(volatile) - return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) + return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) @builtin def store(pointer, value, mask=None, _builder=None): """ - Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. + Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. :param pointer: The memory locations where the elements of :code:`value` are stored. :type pointer: Block of dtype=triton.PointerDType - :param value: The tensor of elements to be stored. + :param value: The block of elements to be stored. :type value: Block :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`. :type mask: Block of triton.int1, optional """ - # value can be constexpr - value = _to_tensor(value, _builder) - if mask is not None: - mask = _to_tensor(mask, _builder) - return semantic.store(pointer, value, mask, _builder) + return frontend.store(pointer, value, mask, _builder) # ----------------------- @@ -844,58 +621,49 @@ def _add_atomic_docstr(name): @builtin @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): - cmp = _to_tensor(cmp, _builder) - val = _to_tensor(cmp, _builder) - return semantic.atomic_cas(pointer, cmp, val, _builder) + return frontend.atomic_cas(pointer, cmp, val, _builder) @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_xchg(pointer, val, mask, _builder) + return frontend.atomic_xchg(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_add(pointer, val, mask, _builder) + return frontend.atomic_add(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("max") def atomic_max(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_max(pointer, val, mask, _builder) + return frontend.atomic_max(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("min") def atomic_min(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_min(pointer, val, mask, _builder) + return frontend.atomic_min(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical and") def atomic_and(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_and(pointer, val, mask, _builder) + return frontend.atomic_and(pointer, val, mask, _builder) 
@builtin @_add_atomic_docstr("logical or") def atomic_or(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_or(pointer, val, mask, _builder) + return frontend.atomic_or(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical xor") def atomic_xor(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_xor(pointer, val, mask, _builder) + return frontend.atomic_xor(pointer, val, mask, _builder) # ----------------------- @@ -906,7 +674,7 @@ def atomic_xor(pointer, val, mask=None, _builder=None): @builtin def where(condition, x, y, _builder=None): """ - Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. @@ -920,10 +688,7 @@ def where(condition, x, y, _builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. """ - condition = _to_tensor(condition, _builder) - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) - return semantic.where(condition, x, y, _builder) + return frontend.where(condition, x, y, _builder) # ----------------------- @@ -932,15 +697,12 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) - return semantic.umulhi(x, y, _builder) + return frontend.umulhi(x, y, _builder) @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): - ieee_rounding = _constexpr_to_value(ieee_rounding) - return semantic.fdiv(x, y, ieee_rounding, _builder) + return frontend.fdiv(x, y, ieee_rounding, _builder) def _add_math_1arg_docstr(name): @@ -961,31 +723,31 @@ def _add_math_1arg_docstr(name): @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): - return semantic.exp(x, _builder) + return frontend.exp(x, _builder) @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): - return semantic.log(x, _builder) + return frontend.log(x, _builder) @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): - return semantic.cos(x, _builder) + return frontend.cos(x, _builder) @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): - return semantic.sin(x, _builder) + return frontend.sin(x, _builder) @builtin @_add_math_1arg_docstr("square root") def sqrt(x, _builder=None): - return semantic.sqrt(x, _builder) + return frontend.sqrt(x, _builder) # ----------------------- @@ -996,7 +758,7 @@ def _add_reduction_docstr(name): def _decorator(func): docstr = """ - Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis` :param input: the input values :param axis: the dimension along which the reduction should be done @@ -1010,29 +772,25 @@ def _add_reduction_docstr(name): @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.max(input, axis, _builder) + return frontend.max(input, axis, _builder) @builtin @_add_reduction_docstr("minimum") def min(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.min(input, axis, _builder) + return frontend.min(input, axis, _builder) @builtin 
@_add_reduction_docstr("sum") def sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.sum(input, axis, _builder) + return frontend.sum(input, axis, _builder) @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.xor_sum(input, axis, _builder) + return frontend.xor_sum(input, axis, _builder) # ----------------------- @@ -1042,7 +800,7 @@ def xor_sum(input, axis, _builder=None): @builtin def debug_barrier(_builder=None): - return semantic.debug_barrier(_builder) + return frontend.debug_barrier(_builder) @builtin @@ -1050,8 +808,7 @@ def multiple_of(input, value, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ - value = _constexpr_to_value(value) - return semantic.multiple_of(input, value) + return frontend.multiple_of(input, value, _builder) @builtin @@ -1059,8 +816,7 @@ def max_contiguous(input, value, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - value = _constexpr_to_value(value) - return semantic.max_contiguous(input, value) + return frontend.max_contiguous(input, value, _builder) # ----------------------- @@ -1090,9 +846,9 @@ def minimum(x, y): """ Computes the element-wise minimum of :code:`x` and :code:`y`. - :param input: the first input tensor + :param input: the first input block :type input: Block - :param other: the second input tensor + :param other: the second input block :type other: Block """ return triton.language.where(x < y, x, y) @@ -1103,9 +859,9 @@ def maximum(x, y): """ Computes the element-wise maximum of :code:`x` and :code:`y`. - :param input: the first input tensor + :param input: the first input block :type input: Block - :param other: the second input tensor + :param other: the second input block :type other: Block """ return triton.language.where(x > y, x, y) @@ -1131,7 +887,7 @@ def ravel(x): """ Returns a contiguous flattened view of :code:`x` - :param x: the input tensor + :param x: the input block :type x: Block """ return triton.language.reshape(x, [x.numel]) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py deleted file mode 100644 index 4063b86fc..000000000 --- a/python/triton/language/semantic.py +++ /dev/null @@ -1,1037 +0,0 @@ -from __future__ import annotations # remove after python 3.11 - -from typing import List, Optional, Tuple - -from . 
import core as tl -from triton._C.libtriton.triton import ir - - -# Create custom exception that prints message "hello" -class IncompatibleTypeErrorimpl(Exception): - def __init__(self, type_a, type_b): - self.type_a = type_a - self.type_b = type_b - self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() - super(IncompatibleTypeErrorimpl, self).__init__(self.message) - - -# ===----------------------------------------------------------------------===## -# Programming Model -# ===----------------------------------------------------------------------===## - -def program_id(axis: int, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_get_program_id(axis), tl.int32) - - -def num_programs(axis: int, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_get_num_programs(axis), tl.int32) - -# ===----------------------------------------------------------------------===// -# Implicit Casting Utilities -# ===----------------------------------------------------------------------===// - - -def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: - a_rank = a_ty.int_bitwidth - b_rank = b_ty.int_bitwidth - a_sn = a_ty.int_signedness - b_sn = b_ty.int_signedness - # Rules for signedness taken from "Usual arithmetic conversions" on - # https://en.cppreference.com/w/c/language/conversion. - if a_sn == b_sn: - return a_ty if a_rank > b_rank else b_ty - elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: - return a_ty if a_rank >= b_rank else b_ty - elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: - return b_ty if b_rank >= a_rank else a_ty - assert False - - -def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: - # 1) if one operand is double, the other is implicitly - # converted to double - if a_ty.is_fp64() or b_ty.is_fp64(): - return tl.float64 - # 2) if one operand is float, the other is implicitly - # converted to float - if a_ty.is_fp32() or b_ty.is_fp32(): - return tl.float32 - # 3 ) if one operand is half, the other is implicitly converted to half - # unless we're doing / or %, which do not exist natively in PTX for fp16. - if a_ty.is_fp16() or b_ty.is_fp16(): - if div_or_mod: - return tl.float32 - else: - return tl.float16 - if not a_ty.is_int() or not b_ty.is_int(): - assert False - # 4 ) both operands are integer and undergo - # integer promotion - if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: - raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" - "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") - return integer_promote_impl(a_ty, b_ty) - -# ===----------------------------------------------------------------------===// -# Binary Operators -# ===----------------------------------------------------------------------===// - - -def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: - if type_a.is_ptr(): - if not allow_ptr_a: - raise IncompatibleTypeErrorimpl(type_a, type_b) - # T* + U* with T != U - if type_b.is_ptr() and (type_a != type_b): - raise IncompatibleTypeErrorimpl(type_a, type_b) - # T* + float - if type_b.is_floating(): - raise IncompatibleTypeErrorimpl(type_a, type_b) - - -def binary_op_type_checking_impl(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder, - allow_lhs_ptr=False, allow_rhs_ptr=False, - arithmetic_check=True, div_or_mod=False - ) -> Tuple[tl.tensor, tl.tensor]: - # implicit broadcasting - lhs, rhs = broadcast_impl_value(lhs, rhs, builder) - # implicit typecasting - lhs_sca_ty = lhs.type.scalar - rhs_sca_ty = rhs.type.scalar - check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) - check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) - if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): - ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) - lhs = cast(lhs, ret_sca_ty, builder) - rhs = cast(rhs, ret_sca_ty, builder) - return lhs, rhs - - -def add(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, True, True) - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - - # offset + ptr - # ptr + offset - if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): - input, other = other, input - if input_scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) - # float + float - elif input_scalar_ty.is_floating(): - return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) - # int + int - elif input_scalar_ty.is_int(): - return tl.tensor(builder.create_add(input.handle, other.handle), input.type) - assert False - - -def sub(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, True, False) - scalar_ty = input.type.scalar - # ptr - offset - if scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), - input.type) - # float - float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) - # int - int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) - assert False - - -def mul(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float * float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) - # * int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) - assert False - - -def truediv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - # float / int - if input_scalar_ty.is_floating() and 
other_scalar_ty.is_int(): - other = cast(other, input_scalar_ty, builder) - # int / float - elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): - input = cast(input, other_scalar_ty, builder) - # int / int (cast to tl.float32) - elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): - input = cast(input, tl.float32, builder) - other = cast(other, tl.float32, builder) - # float / float (cast to highest exponent type) - elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): - if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: - other = cast(other, input_scalar_ty, builder) - else: - input = cast(input, other_scalar_ty, builder) - # unreachable - else: - assert False - return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) - - -def floordiv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - if input_scalar_ty.is_int() and other_scalar_ty.is_int(): - ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) - input = cast(input, ret_ty, builder) - other = cast(other, ret_ty, builder) - if ret_ty.is_int_signed(): - return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) - else: - return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) - assert False - - -def fdiv(input: tl.tensor, - other: tl.tensor, - ieee_rounding: bool, - builder: ir.builder) -> tl.tensor: - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): - raise ValueError("both operands of fdiv must have floating poscalar type") - input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) - ret = builder.create_fdiv(input.handle, other.handle) - ret.set_fdiv_ieee_rounding(ieee_rounding) - return tl.tensor(ret, input.type) - - -def mod(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) - scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - # float % float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) - # % int - elif scalar_ty.is_int(): - if scalar_ty.int_signedness != other_scalar_ty.int_signedness: - raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " - "because they have different signedness;" - "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) - else: - return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) - assert False - -############## -# bitwise ops -############## - - -def bitwise_op_type_checking_impl(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) - input_sca_ty = input.type.scalar - other_sca_ty = other.type.scalar - if not input_sca_ty.is_int() or not other_sca_ty.is_int(): - raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty) - ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) - if ret_sca_ty != input_sca_ty: - input = cast(input, ret_sca_ty, builder) - if ret_sca_ty != other_sca_ty: - other = cast(other, ret_sca_ty, builder) - return input, other - - -def and_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_and(input.handle, other.handle), input.type) - - -def or_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_or(input.handle, other.handle), input.type) - - -def xor_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) - - -def lshr(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) - - -def shl(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) - -# ===----------------------------------------------------------------------===// -# Unary Operators -# ===----------------------------------------------------------------------===// - - -def plus(input: tl.tensor) -> tl.tensor: - return input - - -def minus(input: tl.tensor, - builder: ir.builder) -> tl.tensor: - input_sca_ty = input.type.scalar - if input_sca_ty.is_ptr(): - raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") - _0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) - return sub(_0, input, builder) - - -def invert(input: tl.tensor, - builder: tl.tensor) -> tl.tensor: - input_sca_ty = input.type.scalar - if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): - raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") - _1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) - return xor_(input, _1, builder) - - -# ===----------------------------------------------------------------------===// -# Comparison Operators -# ===----------------------------------------------------------------------===// -def _bool_like(v: tl.tensor) -> tl.block_type: - if not v.type.is_block(): - return tl.int1 - shape = v.type.shape - return tl.block_type(tl.int1, shape) - - -def greater_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> 
tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float > float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) - # > int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) - assert False - - -def greater_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float >= float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) - # >= int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) - assert False - - -def less_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float < float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) - # < int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) - assert False - - -def less_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float < float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) - # < int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) - assert False - - -def equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float == float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) - # == int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) - assert False - - -def not_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float == float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) - # == int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) - assert False - -# ===----------------------------------------------------------------------===// -# Block Creation -# ===----------------------------------------------------------------------===// - - -def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: - 
shape = [end - start] - ret_ty = tl.block_type(tl.int32, shape) - return tl.tensor(builder.get_range(start, end), ret_ty) - - -def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - _0 = ir.constant.get_null_value(dtype.to_ir(builder)) - ret_ty = tl.block_type(dtype, shape) - return tl.tensor(builder.create_splat(_0, shape), ret_ty) - -# ===----------------------------------------------------------------------===// -# Shape Manipulation -# ===----------------------------------------------------------------------===// - - -def reshape(input: tl.tensor, - dst_shape: List[int], - builder: ir.builder) -> tl.tensor: - numel = 1 - for s in dst_shape: - numel *= s - if input.type.numel != numel: - raise ValueError("cannot reshape block of different shape") - ret_ty = tl.block_type(input.type.scalar, dst_shape) - return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty) - - -def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: - # TODO: check types - return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type) - - -def broadcast_impl_shape(input: tl.tensor, - shape: List[int], - builder: ir.builder) -> tl.tensor: - if not input.type.is_block(): - ret_ty = tl.block_type(input.type, shape) - return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) - src_shape = input.type.get_block_shapes() - if len(src_shape) != len(shape): - raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") - if shape == src_shape: - return input - ret_ty = tl.block_type(input.type.scalar, shape) - return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) - - -def broadcast_impl_value(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder) -> tl.tensor: - lhs_ty = lhs.type - rhs_ty = rhs.type - - # make_shape_compatible(block, scalar) - if lhs_ty.is_block() and not rhs_ty.is_block(): - rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) - rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) - # make_shape_compatible(scalar, block) - elif not lhs_ty.is_block() and rhs_ty.is_block(): - lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) - lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) - # make_shape_compatible(block, block) - elif lhs_ty.is_block() and rhs_ty.is_block(): - lhs_shape = lhs_ty.get_block_shapes() - rhs_shape = rhs_ty.get_block_shapes() - if len(lhs_shape) != len(rhs_shape): - raise ValueError("Cannot make_shape_compatible: blocks must have the same rank") - ret_shape = [] - for i in range(len(lhs_shape)): - left = lhs_shape[i] - right = rhs_shape[i] - if left == 1: - ret_shape.append(right) - elif right == 1: - ret_shape.append(left) - elif left == right: - ret_shape.append(left) - else: - raise ValueError("Cannot make_shape_compatible: incompatible dimensions " - "at index " + str(i) + ": " + str(left) + " and " + str(right)) - if lhs_shape != ret_shape: - ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) - lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) - if rhs_shape != ret_shape: - ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) - rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) - # (scalar, scalar) => returns original blocks - return lhs, rhs - -####### -# cast -####### - - -def bitcast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: - src_ty = input.type - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) - if 
src_ty == dst_ty: - return input - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): - return cast(input, dst_ty, builder) - # Bitcast - src_bits = src_sca_ty.primitive_bitwidth - dst_bits = dst_sca_ty.primitive_bitwidth - if src_bits != dst_bits: - raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " - "data-type of size " + str(dst_bits)) - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), - dst_ty) - - -def cast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: - src_ty = input.type - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) - if src_ty == dst_ty: - return input - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - - # bf16 <=> (not fp32) - if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ - (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): - return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) - - # FP Truncation - truncate_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width - if truncate_fp: - return tl.tensor(builder.create_fp_trunc(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # FP Extension - ext_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width - if ext_fp: - return tl.tensor(builder.create_fp_ext(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # Int cast - if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ - (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): - sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() - return tl.tensor(builder.create_int_cast(input.handle, - dst_ty.to_ir(builder), sign_extend), - dst_ty) - - # Float to Int - if src_sca_ty.is_floating() and dst_sca_ty.is_int(): - # TODO: is this correct? - if dst_sca_ty.is_bool(): - return tl.tensor(builder.create_fp_to_ui(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - else: - return tl.tensor(builder.create_fp_to_si(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # int => float - if src_sca_ty.is_int() and dst_sca_ty.is_floating(): - if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): - return tl.tensor(builder.create_ui_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - else: - return tl.tensor(builder.create_si_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # ptr => int - if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): - bitwidth = dst_sca_ty.int_bitwidth - if bitwidth == 64: - return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), - dst_ty) - if bitwidth == 1: - return not_equal(cast(input, tl.int64, builder), - tl.tensor(builder.get_int64(0), tl.int64), - builder) - - if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) - # Ptr . Ptr - if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) - # * . 
Bool - if dst_sca_ty.is_bool(): - if src_sca_ty.is_ptr(): - input = cast(input, tl.int64, builder) - other = builder.get_int64(0) - if src_ty.is_bool(): - other = builder.create_splat(other, src_ty.get_block_shapes()) - return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty) - assert False, f'cannot cast {input} to {dst_ty}' - -# ===----------------------------------------------------------------------===// -# Memory Operators -# ===----------------------------------------------------------------------===// - - -def load(ptr: tl.tensor, - mask: Optional[tl.tensor], - other: Optional[tl.tensor], - cache_modifier: str, - eviction_policy: str, - is_volatile: bool, - builder: ir.builder) -> tl.tensor: - if not ptr.type.scalar.is_ptr(): - raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__()) - if ptr.type.is_block(): - if mask: - mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) - if other: - other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) - - if other: - other = cast(other, ptr.type.scalar.element_ty, builder) - ptr_ty = ptr.type.scalar - elt_ty = ptr_ty.element_ty - # treat bool* as tl.int8* - if elt_ty == tl.int1: - elt_ty = tl.int8 - ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) - ptr = cast(ptr, ptr_ty, builder) - - # cache modifier - cache = ir.CACHE_MODIFIER.NONE # default - if cache_modifier: - if cache_modifier == ".ca": - cache = ir.CACHE_MODIFIER.CA - elif cache_modifier == ".cg": - cache = ir.CACHE_MODIFIER.CG - else: - raise ValueError(f"Cache modifier {cache_modifier} not supported") - - # eviction policy - eviction = ir.EVICTION_POLICY.NORMAL # default - if eviction_policy: - if eviction_policy == "evict_last": - eviction = ir.EVICTION_POLICY.EVICT_LAST - elif eviction_policy == "evict_first": - eviction = ir.EVICTION_POLICY.EVICT_FIRST - else: - raise ValueError(f"Eviction policy {eviction_policy} not supported") - - if ptr.type.is_block(): - shape = ptr.type.get_block_shapes() - dst_ty = tl.block_type(elt_ty, shape) - else: - dst_ty = elt_ty - - if not mask and not other: - return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), - dst_ty) - if not mask: - raise ValueError("`other` cannot be provided without `mask`") - - if not other: - other_ir = ir.undef.get(elt_ty.to_ir(builder)) - if ptr.type.is_block(): - other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes()) - other = tl.tensor(other_ir, dst_ty) - - return tl.tensor(builder.create_masked_load(ptr.handle, - mask.handle, - other.handle, - cache, eviction, is_volatile), - dst_ty) - - -def store(ptr: tl.tensor, - val: tl.tensor, - mask: Optional[tl.tensor], - builder: ir.builder) -> tl.tensor: - if not ptr.type.scalar.is_ptr(): - raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) - if ptr.type.is_block(): - val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) - if mask: - mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) - ptr_ty = ptr.type.scalar - elt_ty = ptr_ty.element_ty - # treat bool* as tl.int8* - if elt_ty == tl.int1: - elt_ty = tl.int8 - ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) - ptr = cast(ptr, ptr_ty, builder) - - # cast to target data-type - val = cast(val, elt_ty, builder) - if not mask: - return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void) - if not mask.type.scalar.is_bool(): - raise ValueError("Mask must have boolean scalar type") - return 
tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) - -######### -# atomic -######### - - -def atomic_cas(ptr: tl.tensor, - cmp: tl.tensor, - val: tl.tensor, - builder: ir.builder) -> tl.tensor: - # TODO: type checking - return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) - - -def atom_red_typechecking_impl(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: - if not ptr.type.scalar.is_ptr(): - raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) - if ptr.type.is_block(): - if mask: - mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) - if val: - val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) - val = cast(val, ptr.type.scalar.element_ty, builder) - if not mask: - mask_ir = builder.get_int1(True) - mask_ty = tl.int1 - if ptr.type.is_block(): - mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) - mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) - mask = tl.tensor(mask_ir, mask_ty) - return ptr, val, mask - - -def atomic_max(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - sca_ty = val.type.scalar - # direct call to atomic_max for integers - if sca_ty.is_int(): - if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, - ptr.handle, - val.handle, - mask.handle), - val.type) - else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - ptr.handle, - val.handle, - mask.handle), - val.type) - # for float - # return atomic_smax(i_ptr, i_val) if val >= 0 - # return atomic_umin(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type) - return where(pos, pos_ret, neg_ret, builder) - - -def atomic_min(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - sca_ty = val.type.scalar - # direct call to atomic_min for integers - if sca_ty.is_int(): - if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - ptr.handle, - val.handle, - mask.handle), - val.type) - else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, - ptr.handle, - val.handle, - mask.handle), - val.type) - # for float - # return atomic_smin(i_ptr, i_val) if val >= 0 - # return atomic_umax(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - i_ptr.handle, - i_val.handle, - and_(mask, pos, 
builder).handle), - i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - i_ptr.handle, - i_val.handle, - and_(mask, neg, builder).handle), - i_val.type) - return where(pos, pos_ret, neg_ret, builder) - - -def atomic_add(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - sca_ty = val.type.scalar - op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD - return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_and(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_or(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_xor(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_xchg(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) - -# ===----------------------------------------------------------------------===// -# Linear Algebra -# ===----------------------------------------------------------------------===// - - -def dot(lhs: tl.tensor, - rhs: tl.tensor, - allow_tf32: bool, - builder: ir.builder) -> tl.tensor: - assert lhs.type.is_block() and rhs.type.is_block() - if lhs.type.scalar.is_int(): - _0 = builder.get_int32(0) - ret_scalar_ty = tl.int32 - else: - _0 = builder.get_float32(0) - ret_scalar_ty = tl.float32 - M = lhs.type.shape[0] - N = rhs.type.shape[1] - _0 = builder.create_splat(_0, [M, N]) - ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), - ret_ty) - - -# ===----------------------------------------------------------------------===// -# Indexing -# ===----------------------------------------------------------------------===// - -def where(condition: tl.tensor, - x: tl.tensor, - y: tl.tensor, - builder: ir.builder) -> tl.tensor: - condition = cast(condition, tl.int1, builder) - if condition.type.is_block(): - x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) - y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) - - # TODO: we need to check x's and y's shape? 
-    x_ty = x.type.scalar
-    y_ty = y.type.scalar
-    ty = computation_type_impl(x_ty, y_ty, div_or_mod=False)
-    x = cast(x, ty, builder)
-    y = cast(y, ty, builder)
-    if x.type.is_block():
-        ret_ty = tl.block_type(ty, x.type.shape)
-    else:
-        ret_ty = ty
-    return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
-
-
-# ===----------------------------------------------------------------------===//
-# Reductions
-# ===----------------------------------------------------------------------===
-
-def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
-                FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
-    scalar_ty = input.type.scalar
-    # input is extended to 32-bits if necessary
-    # this increases numerical accuracy and can be done pretty much for free
-    # on GPUs
-    if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
-        input = cast(input, tl.int32, builder)
-
-    # get result type
-    shape = input.type.shape
-    ret_shape = []
-    for i, s in enumerate(shape):
-        if i != axis:
-            ret_shape.append(s)
-    if len(ret_shape) == 0:
-        res_ty = scalar_ty
-    else:
-        res_ty = tl.block_type(scalar_ty, ret_shape)
-
-    if scalar_ty.is_floating():
-        return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
-    elif scalar_ty.is_int():
-        return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty)
-    assert False
-
-
-def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
-    return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
-
-
-def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
-    return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
-
-
-def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
-    return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
-
-
-def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
-    scalar_ty = input.type.scalar
-    if not scalar_ty.is_int():
-        raise ValueError("xor_sum only supported for integers")
-    return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
-
-
-# ===----------------------------------------------------------------------===
-# Math
-# ===----------------------------------------------------------------------===
-
-def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
-    x, y = binary_op_type_checking_impl(x, y, builder)
-    return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type)
-
-
-def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
-    return tl.tensor(builder.create_exp(x.handle), x.type)
-
-
-def log(x: tl.tensor, builder: ir.builder) -> tl.tensor:
-    return tl.tensor(builder.create_log(x.handle), x.type)
-
-
-def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor:
-    return tl.tensor(builder.create_cos(x.handle), x.type)
-
-
-def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor:
-    return tl.tensor(builder.create_sin(x.handle), x.type)
-
-
-def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
-    return tl.tensor(builder.create_sqrt(x.handle), x.type)
-
-
-##
-
-def multiple_of(x: tl.tensor, value: int) -> tl.tensor:
-    x.handle.multiple_of(value)
-    return x
-
-
-def max_contiguous(x: tl.tensor, value: int) -> tl.tensor:
-    x.handle.max_contiguous(value)
-    return x
-
-
-def debug_barrier(builder: ir.builder) -> tl.tensor:
-    return tl.tensor(builder.create_barrier(''), tl.void)

From e0cc48805506b1f3cd99a08d25e331ed3dbf8a45 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 28 Mar 2022 16:15:43 -0700
Subject: [PATCH 082/215] [FRONTEND] Added `tl.clock` and `tl.globaltimer` (#485)

---
 include/triton/codegen/selection/generator.h |  3 ++
 include/triton/ir/builder.h                  |  4 ++-
 include/triton/ir/dispatch.h                 |  4 +++
 include/triton/ir/enums.h                    |  2 ++
 include/triton/ir/instructions.h             | 21 +++++++++++++
 include/triton/ir/visitor.h                  |  4 +++
 lib/codegen/selection/generator.cc           | 16 ++++++++--
 lib/driver/llvm.cc                           |  3 +-
 lib/ir/dispatch.cc                           | 10 ++++++
 lib/ir/instructions.cc                       | 33 ++++++++------------
 python/src/triton.cc                         |  3 ++
 python/triton/language/core.py               | 13 ++++++++
 python/tutorials/01-vector-add.py            |  8 ++++-
 13 files changed, 99 insertions(+), 25 deletions(-)

diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index ad7d01a55..293aa8908 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -182,6 +182,8 @@ public:
   void visit_async_wait_inst(ir::async_wait_inst*);
//  void visit_make_range_dyn(ir::make_range_dyn*);
   void visit_make_range(ir::make_range*);
+  void visit_clock_inst(ir::clock_inst*);
+  void visit_globaltimer_inst(ir::globaltimer_inst*);
//  void visit_make_range_sta(ir::make_range_sta*);
   void visit_undef_value(ir::undef_value*);
   void visit_constant_int(ir::constant_int*);
@@ -192,6 +194,7 @@ public:
   void visit_argument(ir::argument*);
   void visit(ir::module &, llvm::Module &);
+
   // layouts
   void visit_layout_mma(analysis::mma_layout*);
   void visit_layout_scanline(analysis::scanline_layout*);
diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 2b6bc6ab3..45a7d5111 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -28,7 +28,9 @@ public:
   // Constructor
   builder(context &ctx);
   // Getters
-  const context& get_context() { return ctx_; }
+  // const context& get_context() const { return ctx_; }
+  context& get_context() { return ctx_; }
+
   // Setters
   void set_insert_point(iterator instr);
   void set_insert_point(instruction* i);
diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h
index ef14043dd..c7f23779c 100644
--- a/include/triton/ir/dispatch.h
+++ b/include/triton/ir/dispatch.h
@@ -101,6 +101,10 @@ struct dispatch{
   static ir::value *sin(ir::value *x, ir::builder *builder);
   static ir::value *sqrt(ir::value *x, ir::builder *builder);

+  // utilities
+  static ir::value *globaltimer(ir::builder *builder);
+  static ir::value *clock(ir::builder *builder);
+
   // internal (debug/optimization)
   static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
   static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h
index 8cb7835f0..2d4c09d79 100644
--- a/include/triton/ir/enums.h
+++ b/include/triton/ir/enums.h
@@ -165,6 +165,8 @@ enum value_id_t: unsigned {
   INST_MAKE_RANGE_STA,
   INST_MAKE_RANGE,
   INST_PREFETCH_S,
+  INST_GLOBALTIMER,
+  INST_CLOCK,
 };
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 0fb85db02..e9e0f0f11 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -971,6 +971,27 @@ private:
   constant_int* last_;
 };

+/* timing utilities */
+class clock_inst: public instruction{
+  clock_inst(context &ctx, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "clock"; }
+  _TRITON_DEFINE_CLONE(clock_inst)
+  _TRITON_DEFINE_ACCEPT(clock_inst)
+
+public:
+  static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
+};
+
+class globaltimer_inst: public instruction{
+  globaltimer_inst(context &ctx, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "globaltimer"; }
+  _TRITON_DEFINE_CLONE(globaltimer_inst)
+  _TRITON_DEFINE_ACCEPT(globaltimer_inst)
+
+public:
+  static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
+};
+
 }
 }
diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h
index 4979b0b52..25ce578e3 100644
--- a/include/triton/ir/visitor.h
+++ b/include/triton/ir/visitor.h
@@ -75,6 +75,8 @@ class async_wait_inst;
 class make_range_dyn;
 class make_range;
 class prefetch_s_inst;
+class clock_inst;
+class globaltimer_inst;
 class make_range_sta;
 class undef_value;
@@ -157,6 +159,8 @@ public:
   virtual void visit_make_range(make_range*) = 0;
   virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
   virtual void visit_function(function*) = 0;
+  virtual void visit_clock_inst(clock_inst*) = 0;
+  virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;

   virtual void visit_undef_value(undef_value*) = 0;
   virtual void visit_constant_int(constant_int*) = 0;
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index f8cf08cba..b36f51d92 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1093,10 +1093,10 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
       case tt::Xchg: name = "exch", s_ty = "b"; break;
     }
     std::string s_vec = vec == 2 ? "x2" : "";
-    std::string mod = nbits == 32 ? "" : ".noftz";
+    std::string mod = nbits == 16 ? ".noftz" : "";
     std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;";
-    std::string ty_id = nbits*vec == 32 ? "r" : "h";
+    std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ?
"r" : "h"); std::string constraint = "=" + ty_id + ",b,l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); @@ -2597,6 +2597,18 @@ void generator::visit_barrier_inst(ir::barrier_inst*) { add_barrier(); } +void generator::visit_clock_inst(ir::clock_inst* clock){ + InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true); + vals_[clock][{}] = call(iasm); +} + +void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){ + InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true); + vals_[timer][{}] = call(iasm); +} + + + void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) { ir::value *v = i->get_operand(0); int inc = i->get_inc(); diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 0d1c9c3d2..726ac9a97 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -171,13 +171,14 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ // create machine module->setTargetTriple(triple); std::string error; + llvm::TargetMachine* machine; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); llvm::TargetOptions opt; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, + machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); // set data layout if(layout.empty()) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 664fbb983..aabbc4385 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -854,6 +854,16 @@ ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { return builder->create_sqrt(x); } +// + +ir::value *dispatch::globaltimer(ir::builder *builder) { + return builder->insert(globaltimer_inst::create(builder->get_context())); +} + +ir::value *dispatch::clock(ir::builder *builder) { + return builder->insert(clock_inst::create(builder->get_context())); + +} // diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index c225b315f..d1f81f136 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -861,8 +861,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri } // barrier -barrier_inst::barrier_inst(context &ctx, const std::string &name, - instruction *next) +barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next) : instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { } barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) { @@ -881,27 +880,21 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons return new prefetch_s_inst(ctx, arg, inc, name, next); } -//// nv_dynamic_program_idx -//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) -// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } +// global timer +globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next) + : instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { } -//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { -// return new make_range_dyn(ty, name, next); -//} +globaltimer_inst* globaltimer_inst::create(context 
&ctx, const std::string &name, instruction *next) { + return new globaltimer_inst(ctx, name, next); +} -//// nv_static_program_idx -//make_range_sta::make_range_sta(make_range *range) -// : constant(range->get_type(), 0), range_(range) { } +// clock +clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next) + : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { } -//make_range* make_range_sta::get_range() const -//{ return range_; } - -//make_range_sta* make_range_sta::get(make_range* range) { -// static std::map cache; -// if(cache.find(range) == cache.end()) -// cache.insert({range, new make_range_sta(range)}); -// return cache.at(range); -//} +clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) { + return new clock_inst(ctx, name, next); +} // make_range diff --git a/python/src/triton.cc b/python/src/triton.cc index 9e53cc341..22017ebf5 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -614,6 +614,9 @@ void init_triton_frontend(py::module &&m) { m.def("cos", &ir::dispatch::cos, ret::reference); m.def("sin", &ir::dispatch::sin, ret::reference); m.def("sqrt", &ir::dispatch::sqrt, ret::reference); + // utilities + m.def("clock", &ir::dispatch::clock, ret::reference); + m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference); // internal (debugging only) m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df25e59fb..0312d8146 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -792,6 +792,19 @@ def sum(input, axis, _builder=None): def xor_sum(input, axis, _builder=None): return frontend.xor_sum(input, axis, _builder) +# ----------------------- +# Utilities +# ----------------------- + + +@builtin +def globaltimer(_builder=None): + return frontend.globaltimer(_builder) + + +@builtin +def clock(_builder=None): + return frontend.clock(_builder) # ----------------------- # Internal for debugging diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index d684106f1..c0fb85328 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -24,9 +24,11 @@ def add_kernel( y_ptr, # *Pointer* to second input vector output_ptr, # *Pointer* to output vector n_elements, # Size of the vector + time_start_ptr, time_end_ptr, BLOCK_SIZE: tl.constexpr, # Number of elements each program should process # NOTE: `constexpr` so it can be used as a shape value ): + tl.atomic_min(time_start_ptr, tl.clock()) # There are multiple 'program's processing different data. We identify which program # we are here pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 @@ -45,6 +47,7 @@ def add_kernel( output = x + y # Write x + y back to DRAM tl.store(output_ptr + offsets, output, mask=mask) + tl.atomic_max(time_end_ptr, tl.clock()) # %% @@ -53,6 +56,8 @@ def add_kernel( def add(x: torch.Tensor, y: torch.Tensor): + time_start = torch.zeros(1, dtype=torch.int64, device='cuda') + time_end = torch.zeros(1, dtype=torch.int64, device='cuda') # We need to preallocate the output output = torch.empty_like(x) assert x.is_cuda and y.is_cuda and output.is_cuda @@ -65,9 +70,10 @@ def add(x: torch.Tensor, y: torch.Tensor): # - each torch.tensor object is implicitly converted into a pointer to its first element. 
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel # - don't forget to pass meta-parameters as keywords arguments - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. + print((time_end, time_start)) return output From bace26143df15dbf3da7ee769b050c1471eb4754 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 28 Mar 2022 16:53:23 -0700 Subject: [PATCH 083/215] [TUTORIALS] Removed leftover print --- python/tutorials/01-vector-add.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index c0fb85328..51de7ac6c 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -73,7 +73,6 @@ def add(x: torch.Tensor, y: torch.Tensor): add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. - print((time_end, time_start)) return output From e85c7a7fc7cd48349905d543884300e1e726174a Mon Sep 17 00:00:00 2001 From: apd10 <57877560+apd10@users.noreply.github.com> Date: Wed, 30 Mar 2022 22:45:41 -0500 Subject: [PATCH 084/215] Bugfix in ptxas path. (#487) Bug: "ret" value is destroyed when a failing "ptxas --version" is run overwriting the previous valid "ret" value. Fix: keep rets only for those runs which are successful. Pick the first one --- lib/driver/llvm.cc | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 726ac9a97..463f45712 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -94,6 +94,7 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s } std::string path_to_ptxas(int& version) { + std::vector rets; std::string ret; // search pathes for ptxas std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; @@ -105,8 +106,10 @@ std::string path_to_ptxas(int& version) { for(std::string prefix: ptxas_prefixes){ std::string ptxas = prefix + "ptxas"; bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0; - if(works) + if(works) { working_ptxas.push_back(ptxas); + rets.push_back(ret); + } } // error if no working ptxas was found if(working_ptxas.empty()) @@ -116,13 +119,20 @@ std::string path_to_ptxas(int& version) { // parse version std::regex version_regex("release (\\d+)\\.(\\d+)"); std::smatch match; - if(std::regex_search(ret, match, version_regex)){ - int major = std::stoi(match[1]); - int minor = std::stoi(match[2]); - version = major*1000 + minor*10; + bool found = false; + // currently choosing the first ptxas. 
Other logics can be implemented in future + for(std::string ret : rets) { + if(std::regex_search(ret, match, version_regex)){ + int major = std::stoi(match[1]); + int minor = std::stoi(match[2]); + version = major*1000 + minor*10; + found = true; + break; + } + } + if ( not found) { + throw std::runtime_error("Error in parsing version"); } - else - throw std::runtime_error("couldn't parse ptxas version: " + ret); return ptxas; } From 2bed6fc850b93854d7f411e77c9a302954604091 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 3 Apr 2022 20:58:16 -0700 Subject: [PATCH 085/215] [LANG] Added support for device functions (#484) --- include/triton/codegen/analysis/layout.h | 1 + include/triton/codegen/selection/generator.h | 8 + include/triton/codegen/transform/inline.h | 31 +++ include/triton/codegen/transform/peephole.h | 3 + include/triton/driver/dispatch.h | 2 + include/triton/external/CUDA/cuda.h | 244 +++++++---------- include/triton/ir/basic_block.h | 16 +- include/triton/ir/builder.h | 8 + include/triton/ir/context_impl.h | 3 +- include/triton/ir/enums.h | 6 + include/triton/ir/function.h | 5 +- include/triton/ir/instructions.h | 83 ++++++ include/triton/ir/module.h | 88 ++++--- include/triton/ir/type.h | 18 +- include/triton/ir/value.h | 4 +- include/triton/ir/visitor.h | 11 + lib/codegen/analysis/layout.cc | 2 + lib/codegen/pass.cc | 3 + lib/codegen/selection/generator.cc | 155 ++++++++++- lib/codegen/transform/dce.cc | 4 + lib/codegen/transform/inline.cc | 127 +++++++++ lib/codegen/transform/peephole.cc | 70 +++-- lib/codegen/transform/pipeline.cc | 1 + lib/driver/dispatch.cc | 1 + lib/driver/error.cc | 2 +- lib/driver/llvm.cc | 2 + lib/ir/basic_block.cc | 63 ++++- lib/ir/builder.cc | 33 ++- lib/ir/dispatch.cc | 3 + lib/ir/function.cc | 6 +- lib/ir/instructions.cc | 100 ++++++- lib/ir/module.cc | 85 +++--- lib/ir/type.cc | 21 +- lib/ir/value.cc | 5 +- python/setup.py | 2 +- python/src/triton.cc | 59 ++++- python/test/unit/language/test_core.py | 30 ++- python/triton/code_gen.py | 262 +++++++++++++------ python/triton/language/core.py | 25 +- 39 files changed, 1213 insertions(+), 379 deletions(-) create mode 100644 include/triton/codegen/transform/inline.h mode change 100755 => 100644 include/triton/external/CUDA/cuda.h create mode 100644 lib/codegen/transform/inline.cc diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 56fb1e4b9..28dfad18d 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -224,6 +224,7 @@ struct scanline_layout: public distributed_layout { int nts(size_t k) { return nts_.at(k); } int contig_per_thread(size_t k) { return nts_.at(k); } + int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);} public: // micro tile size. The size of a tile held by a thread block. 
std::vector mts_; diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 293aa8908..e3191efb1 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -24,6 +24,7 @@ namespace llvm{ class IRBuilder; class ArrayType; class Function; + class StructType; } namespace triton{ @@ -114,6 +115,8 @@ private: private: Type *cvt(ir::type *ty); llvm::Attribute cvt(ir::attribute attr); + llvm::StructType* packed_type(ir::value* i); + void forward_declare(ir::function* fn); public: generator(analysis::axes *a_axes, @@ -125,6 +128,8 @@ public: unsigned num_warps); void visit_value(ir::value* v); + void visit_call_inst(ir::call_inst*); + void visit_launch_inst(ir::launch_inst *); void visit_phi_node(ir::phi_node*); void visit_binary_operator(ir::binary_operator*); void visit_getelementptr_inst(ir::getelementptr_inst*); @@ -148,6 +153,8 @@ public: void visit_unmasked_store_inst(ir::unmasked_store_inst*); void visit_masked_store_inst(ir::masked_store_inst*); void visit_cat_inst(ir::cat_inst*); + void visit_extract_value_inst(ir::extract_value_inst *); + void visit_insert_value_inst(ir::insert_value_inst *); void visit_reshape_inst(ir::reshape_inst*); void visit_splat_inst(ir::splat_inst*); void visit_broadcast_inst(ir::broadcast_inst*); @@ -242,6 +249,7 @@ private: /// triton bb -> llvm bb std::map bbs_; std::map> ords_; + std::map fns_; // helper for creating llvm values adder add; diff --git a/include/triton/codegen/transform/inline.h b/include/triton/codegen/transform/inline.h new file mode 100644 index 000000000..c79079b61 --- /dev/null +++ b/include/triton/codegen/transform/inline.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace triton { + +namespace ir { + class module; + class function; + class call_inst; + class builder; +} + +namespace codegen{ +namespace transform{ + +struct fncmp { + bool operator()(ir::function* x, ir::function* y) const; +}; + +class inliner { +public: + inliner() {} + void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list& callsites); + void run(ir::module &mod); +}; + + +} +} +} diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h index 0e1ed222e..5b84a813b 100644 --- a/include/triton/codegen/transform/peephole.h +++ b/include/triton/codegen/transform/peephole.h @@ -30,6 +30,9 @@ private: bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); bool rewrite_dot(ir::instruction *value, ir::builder& builder); bool rewrite_mult(ir::instruction *value, ir::builder& builder); + bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder); + + bool rewrite_unit_red(ir::instruction *value, ir::builder& builder); bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index 5503bacaf..2384b4cba 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -89,6 +89,7 @@ public: static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); static CUresult cuDeviceGetCount(int *count); // link management + static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void 
**optionValues); static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); @@ -214,6 +215,7 @@ private: static void* cuDeviceGetAttribute_; static void* cuDeviceGetCount_; // link management + static void* cuLinkAddFile_v2_; static void* cuLinkAddData_v2_; static void* cuLinkCreate_v2_; static void* cuLinkDestroy_; diff --git a/include/triton/external/CUDA/cuda.h b/include/triton/external/CUDA/cuda.h old mode 100755 new mode 100644 index f7bf9fc12..2f32c80fa --- a/include/triton/external/CUDA/cuda.h +++ b/include/triton/external/CUDA/cuda.h @@ -224,7 +224,7 @@ typedef uint64_t cuuint64_t; /** * CUDA API version number */ -#define CUDA_VERSION 11050 +#define CUDA_VERSION 11040 #ifdef __cplusplus extern "C" { @@ -496,33 +496,7 @@ typedef enum CUarray_format_enum { CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */ CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */ CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */ - CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */ - CU_AD_FORMAT_UNORM_INT8X1 = 0xc0, /**< 1 channel unsigned 8-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT8X2 = 0xc1, /**< 2 channel unsigned 8-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT8X4 = 0xc2, /**< 4 channel unsigned 8-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT16X1 = 0xc3, /**< 1 channel unsigned 16-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT16X2 = 0xc4, /**< 2 channel unsigned 16-bit normalized integer */ - CU_AD_FORMAT_UNORM_INT16X4 = 0xc5, /**< 4 channel unsigned 16-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT8X1 = 0xc6, /**< 1 channel signed 8-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT8X2 = 0xc7, /**< 2 channel signed 8-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT8X4 = 0xc8, /**< 4 channel signed 8-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT16X1 = 0xc9, /**< 1 channel signed 16-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT16X2 = 0xca, /**< 2 channel signed 16-bit normalized integer */ - CU_AD_FORMAT_SNORM_INT16X4 = 0xcb, /**< 4 channel signed 16-bit normalized integer */ - CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format */ - CU_AD_FORMAT_BC1_UNORM_SRGB = 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format with sRGB encoding*/ - CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format */ - CU_AD_FORMAT_BC2_UNORM_SRGB = 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format with sRGB encoding*/ - CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format */ - CU_AD_FORMAT_BC3_UNORM_SRGB = 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format with sRGB encoding*/ - CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized block-compressed (BC4 compression) format */ - CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed normalized block-compressed (BC4 compression) format */ - CU_AD_FORMAT_BC5_UNORM = 0x99, /**< 2 channel unsigned normalized block-compressed (BC5 compression) format */ - CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 
channel signed normalized block-compressed (BC5 compression) format */ - CU_AD_FORMAT_BC6H_UF16 = 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H compression) format */ - CU_AD_FORMAT_BC6H_SF16 = 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) format */ - CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned normalized block-compressed (BC7 compression) format */ - CU_AD_FORMAT_BC7_UNORM_SRGB = 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) format with sRGB encoding */ + CU_AD_FORMAT_NV12 = 0xb0 } CUarray_format; /** @@ -657,7 +631,7 @@ typedef enum CUdevice_attribute_enum { CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */ CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ - CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested ::cuMemCreate */ CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */ CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */ CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes. */ @@ -665,7 +639,7 @@ typedef enum CUdevice_attribute_enum { CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */ CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */ CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */ - CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */ + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */ CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */ CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */ @@ -1650,8 +1624,7 @@ typedef enum cudaError_enum { CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224, /** - * This indicates that the device kernel source is invalid. 
This includes - * compilation/linker errors encountered in device code or user error. + * This indicates that the device kernel source is invalid. */ CUDA_ERROR_INVALID_SOURCE = 300, @@ -2068,9 +2041,9 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize); * On Windows the flag is a no-op. * On Linux that memory is marked as non cache-coherent for the GPU and * is expected to be physically contiguous. It may return - * ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, - * ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. - * On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED + * CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, + * CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. + * On all other platforms, it is not supported and CUDA_ERROR_NOT_SUPPORTED * is returned. * Flag for ::cuMemHostRegister() */ @@ -2079,12 +2052,12 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize); /** * If set, the passed memory pointer is treated as pointing to memory that is * considered read-only by the device. On platforms without -* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is +* CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is * required in order to register memory mapped to the CPU as read-only. Support * for the use of this flag can be queried from the device attribute -* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with +* CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with * a current context associated with a device that does not have this attribute -* set will cause ::cuMemHostRegister to error with ::CUDA_ERROR_NOT_SUPPORTED. +* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED. */ #define CU_MEMHOSTREGISTER_READ_ONLY 0x08 @@ -3735,117 +3708,117 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * \p dev. 
The supported attributes are: * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per * block; - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid - * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid; + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid; + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid; * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of - * shared memory available to a thread block in bytes + * shared memory available to a thread block in bytes; * - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for - * __constant__ variables in a CUDA C kernel in bytes - * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads + * __constant__ variables in a CUDA C kernel in bytes; + * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads; * - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the * memory copy functions that involve memory regions allocated through - * ::cuMemAllocPitch() + * ::cuMemAllocPitch(); * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D - * texture width + * texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width - * for a 1D texture bound to linear memory + * for a 1D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum - * mipmapped 1D texture width + * mipmapped 1D texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D - * texture width + * texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D - * texture height + * texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width - * for a 2D texture bound to linear memory + * for a 2D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height - * for a 2D texture bound to linear memory + * for a 2D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch - * in bytes for a 2D texture bound to linear memory + * in bytes for a 2D texture bound to linear memory; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum - * mipmapped 2D texture width + * mipmapped 2D texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum - * mipmapped 2D texture height + * mipmapped 2D texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D - * texture width + * texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D - * texture height + * texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D - * texture depth + * texture depth; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE: * Alternate maximum 3D texture width, 0 if no alternate - * maximum 3D texture size is supported + * maximum 3D 
texture size is supported; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE: * Alternate maximum 3D texture height, 0 if no alternate - * maximum 3D texture size is supported + * maximum 3D texture size is supported; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE: * Alternate maximum 3D texture depth, 0 if no alternate - * maximum 3D texture size is supported + * maximum 3D texture size is supported; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH: - * Maximum cubemap texture width or height + * Maximum cubemap texture width or height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH: - * Maximum 1D layered texture width + * Maximum 1D layered texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS: - * Maximum layers in a 1D layered texture + * Maximum layers in a 1D layered texture; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH: - * Maximum 2D layered texture width + * Maximum 2D layered texture width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT: - * Maximum 2D layered texture height + * Maximum 2D layered texture height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS: - * Maximum layers in a 2D layered texture + * Maximum layers in a 2D layered texture; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH: - * Maximum cubemap layered texture width or height + * Maximum cubemap layered texture width or height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS: - * Maximum layers in a cubemap layered texture + * Maximum layers in a cubemap layered texture; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH: - * Maximum 1D surface width + * Maximum 1D surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH: - * Maximum 2D surface width + * Maximum 2D surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT: - * Maximum 2D surface height + * Maximum 2D surface height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH: - * Maximum 3D surface width + * Maximum 3D surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT: - * Maximum 3D surface height + * Maximum 3D surface height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH: - * Maximum 3D surface depth + * Maximum 3D surface depth; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH: - * Maximum 1D layered surface width + * Maximum 1D layered surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS: - * Maximum layers in a 1D layered surface + * Maximum layers in a 1D layered surface; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH: - * Maximum 2D layered surface width + * Maximum 2D layered surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT: - * Maximum 2D layered surface height + * Maximum 2D layered surface height; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS: - * Maximum layers in a 2D layered surface + * Maximum layers in a 2D layered surface; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH: - * Maximum cubemap surface width + * Maximum cubemap surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH: - * Maximum cubemap layered surface width + * Maximum cubemap layered surface width; * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS: - * Maximum layers in a cubemap layered surface + * Maximum layers in a cubemap layered surface; * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit - * registers available to a thread block - * - 
::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz + * registers available to a thread block; + * - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz; * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture * base addresses aligned to ::textureAlign bytes do not need an offset - * applied to texture fetches + * applied to texture fetches; * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement - * for 2D texture references bound to pitched memory + * for 2D texture references bound to pitched memory; * - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy - * memory between host and device while executing a kernel, or 0 if not + * memory between host and device while executing a kernel, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on - * the device + * the device; * - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit - * for kernels executed on the device, or 0 if not + * for kernels executed on the device, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the - * memory subsystem, or 0 if not + * memory subsystem, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host - * memory into the CUDA address space, or 0 if not + * memory into the CUDA address space, or 0 if not; * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently * in. Available modes are as follows: * - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and @@ -3858,33 +3831,33 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * executing multiple kernels within the same context simultaneously, or 0 if * not. It is not guaranteed that multiple kernels will be resident * on the device concurrently so this feature should not be relied upon for - * correctness. + * correctness; * - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the - * device, 0 if error correction is disabled or not supported by the device - * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device + * device, 0 if error correction is disabled or not supported by the device; + * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device; * - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier - * of the device + * of the device; * - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device * - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. TCC - * is only available on Tesla hardware running Windows Vista or later - * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz - * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits - * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache - * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor + * is only available on Tesla hardware running Windows Vista or later; + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz; + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits; + * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 
0 if the device doesn't have L2 cache; + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor; * - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with - * the host, or 0 if not - * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number - * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number + * the host, or 0 if not; + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number; + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number; * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals - * in L1 cache, 0 if caching globals in L1 cache is not supported by the device + * in L1 cache, 0 if caching globals in L1 cache is not supported by the device; * - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals - * in L1 cache, 0 if caching locals in L1 cache is not supported by the device + * in L1 cache, 0 if caching locals in L1 cache is not supported by the device; * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of * shared memory available to a multiprocessor in bytes; this amount is shared - * by all thread blocks simultaneously resident on a multiprocessor + * by all thread blocks simultaneously resident on a multiprocessor; * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit * registers available to a multiprocessor; this number is shared by all thread - * blocks simultaneously resident on a multiprocessor + * blocks simultaneously resident on a multiprocessor; * - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory * on this system, 0 if allocating managed memory is not supported by the device on this system. * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not. @@ -3910,20 +3883,14 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, * - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate - * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate - * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes. + * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes. + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor. 
* - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate - * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes - * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate. - * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes - * - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays. - * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU + * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes. + * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU * - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum - * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. - * - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC * * \param pi - Returned device attribute value * \param attrib - Device attribute to query @@ -4690,13 +4657,6 @@ CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArra * It is the responsibility of the calling function to ensure that no API * call issues using \p ctx while ::cuCtxDestroy() is executing. * - * Destroys and cleans up all resources associated with the context. - * It is the caller's responsibility to ensure that the context or its resources - * are not accessed or passed in subsequent API calls and doing so will result in undefined behavior. - * These resources include CUDA types such as ::CUmodule, ::CUfunction, ::CUstream, ::CUevent, - * ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, ::CUsurfref, - * ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and ::CUexternalSemaphore. - * * If \p ctx is current to the calling thread then \p ctx will also be * popped from the current thread's context stack (as though ::cuCtxPopCurrent() * were called). 
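Setting the vendored cuda.h churn aside for a moment, the user-visible payoff of PATCH 082 above is that kernels can now sample the per-SM cycle counter (tl.clock) and the global nanosecond timer (tl.globaltimer) directly. The sketch below is a minimal illustration of that pattern, modeled on the instrumented vector-add tutorial in that patch; the kernel name, tensor names and block size are illustrative only and not part of the patches.

import triton
import triton.language as tl

@triton.jit
def timed_copy(X, Y, N, start_ptr, end_ptr, BLOCK: tl.constexpr):
    # record the earliest SM clock sample observed by any program instance
    tl.atomic_min(start_ptr, tl.clock())
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)
    # record the latest SM clock sample observed by any program instance
    tl.atomic_max(end_ptr, tl.clock())

Here start_ptr and end_ptr would point at single-element int64 tensors (the 64-bit atomics enabled by the generator change in that same patch); seeding the start buffer with a large sentinel and the end buffer with zero lets the atomic_min/atomic_max pair bracket the kernel's residency in clock cycles, and tl.globaltimer() can be sampled the same way when wall-clock nanoseconds are preferred.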
If \p ctx is current to other threads, then \p ctx will @@ -5672,7 +5632,6 @@ CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin); * ::CUDA_ERROR_INVALID_CONTEXT, * ::CUDA_ERROR_INVALID_VALUE * \notefnerr - * \note_destroy_ub * * \sa ::cuModuleGetFunction, * ::cuModuleGetGlobal, @@ -5993,9 +5952,8 @@ cuLinkDestroy(CUlinkState state); /** * \brief Gets free and total memory * - * Returns in \p *total the total amount of memory available to the the current context. - * Returns in \p *free the amount of memory on the device that is free according to the OS. - * CUDA is not guaranteed to be able to allocate all of the memory that the OS reports as free. + * Returns in \p *free and \p *total respectively, the free and total amount of + * memory available for allocation by the CUDA context, in bytes. * * \param free - Returned free memory in bytes * \param total - Returned total memory in bytes @@ -6839,10 +6797,10 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); * * - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory * that is considered read-only by the device. On platforms without - * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is + * CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is * required in order to register memory mapped to the CPU as read-only. Support * for the use of this flag can be queried from the device attribute - * ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with + * CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with * a current context associated with a device that does not have this attribute * set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED. * @@ -8987,7 +8945,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * float16's: * \code CUDA_ARRAY_DESCRIPTOR desc; - desc.Format = CU_AD_FORMAT_HALF; + desc.FormatFlags = CU_AD_FORMAT_HALF; desc.NumChannels = 4; desc.Width = width; desc.Height = height; @@ -8997,7 +8955,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi * of which is two 8-bit unsigned chars: * \code CUDA_ARRAY_DESCRIPTOR arrayDesc; - desc.Format = CU_AD_FORMAT_UNSIGNED_INT8; + desc.FormatFlags = CU_AD_FORMAT_UNSIGNED_INT8; desc.NumChannels = 2; desc.Width = width; desc.Height = height; @@ -9323,7 +9281,7 @@ CUresult CUDAAPI cuArrayDestroy(CUarray hArray); * 4x16-bit float16's: * \code CUDA_ARRAY3D_DESCRIPTOR desc; - desc.Format = CU_AD_FORMAT_HALF; + desc.FormatFlags = CU_AD_FORMAT_HALF; desc.NumChannels = 4; desc.Width = width; desc.Height = height; @@ -15180,7 +15138,7 @@ CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, c * \param nodeParams - Parameters for the node * * When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in - * \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches. + * \param nodeParams.dptr. The allocation's address remains fixed across instantiations and launches. * * If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode, * the allocation can be accessed by nodes ordered after the allocation node but before the free node. 
@@ -15356,9 +15314,7 @@ CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, CUdeviceptr *dpt * * \sa * ::cuGraphAddMemAllocNode, - * ::cuGraphAddMemFreeNode, - * ::cuDeviceSetGraphMemAttribute, - * ::cuDeviceGetGraphMemAttribute + * ::cuGraphAddMemFreeNode */ CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); @@ -15384,7 +15340,6 @@ CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); * ::CUDA_ERROR_INVALID_DEVICE * * \sa - * ::cuDeviceSetGraphMemAttribute, * ::cuGraphAddMemAllocNode, * ::cuGraphAddMemFreeNode */ @@ -15409,7 +15364,6 @@ CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, CUgraphMem_attrib * ::CUDA_ERROR_INVALID_DEVICE * * \sa - * ::cuDeviceGetGraphMemAttribute, * ::cuGraphAddMemAllocNode, * ::cuGraphAddMemFreeNode */ diff --git a/include/triton/ir/basic_block.h b/include/triton/ir/basic_block.h index 840145246..26d406baf 100644 --- a/include/triton/ir/basic_block.h +++ b/include/triton/ir/basic_block.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #ifndef _TRITON_IR_BASIC_BLOCK_H_ #define _TRITON_IR_BASIC_BLOCK_H_ @@ -27,7 +27,7 @@ public: private: // constructors - basic_block(context &ctx, const std::string &name, function *parent); + basic_block(context &ctx, const std::string &name, function *parent, basic_block *next); public: // accessors @@ -35,6 +35,7 @@ public: context& get_context() { return ctx_; } // get iterator to first instruction that is not a phi + void replace_phi_uses_with(basic_block* before, basic_block* after); iterator get_first_non_phi(); // get instruction list @@ -60,13 +61,16 @@ public: inline const instruction &back() const { return *inst_list_.back(); } inline instruction &back() { return *inst_list_.back(); } + void append_instruction(ir::instruction* i); + // split + basic_block* split_before(ir::instruction* loc, const std::string& name); + // predecessors - const std::vector& get_predecessors() const { return preds_; } - const std::vector& get_successors() const { return succs_; } - void add_predecessor(basic_block* pred); + std::vector get_predecessors() const; + std::vector get_successors() const; // factory functions - static basic_block* create(context &ctx, const std::string &name, function *parent); + static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr); void print(std::ostream &os); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 45a7d5111..ff8447124 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -22,6 +22,7 @@ class phi_node; /* Builder */ class builder{ +public: typedef basic_block::iterator iterator; public: @@ -75,6 +76,7 @@ public: value* create_br(basic_block *dest); value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); + value* create_ret(value *ret); // Cast instructions value *create_cast(cast_op_t op, value *v, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty); @@ -86,6 +88,9 @@ public: value* create_fp_trunc(value *src, type *dst_ty); value* create_int_cast(value *src, type *dst_ty, bool is_signed); value *create_downcast(value *arg); + // Call instruction + value* create_call(function* fn, const std::vector& args); + value* create_launch(function* fn, const std::vector& args, const std::vector& grid, value* num_warps); // Phi instruction phi_node* create_phi(type *ty, unsigned num_reserved); // Binary instructions @@ -142,6 +147,9 @@ public: value *create_store(value *ptr, value *val); value 
*create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); value *create_masked_store(value *ptr, value *val, value *mask); + // Struct instructions + value *create_insert_value(value* val, value *elt, size_t idx); + value *create_extract_value(value* val, size_t idx); // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); value *create_reshape(value *arg, const type::block_shapes_t &shapes); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 081ea249d..619ae4c87 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -31,7 +31,8 @@ public: std::map, std::unique_ptr> ptr_tys; // Block types std::map, std::unique_ptr> block_tys; - + // Struct types + std::map struct_tys; // Int constants std::map, std::unique_ptr> int_constants_; // Float constants diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 2d4c09d79..3fa008606 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -95,6 +95,9 @@ enum value_id_t: unsigned { INSTRUCTIONS * ------------ */ INST_BEGIN, + // call + INST_CALL, + INST_LAUNCH, // phi INST_PHI, // arithmetic @@ -129,6 +132,9 @@ enum value_id_t: unsigned { INST_MASKED_LOAD_ASYNC, INST_UNMASKED_STORE, INST_MASKED_STORE, + // struct + INST_EXTRACT_VALUE, + INST_INSERT_VALUE, // retile INST_RESHAPE, INST_SPLAT, diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 9e1bc981a..4e76e60a4 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -24,7 +24,7 @@ public: static argument* create(type *ty, const std::string &name, function *parent = nullptr, unsigned arg_no = 0); function* get_parent() const; - unsigned get_arg_no() const; + unsigned get_arg_no() const; void accept(visitor *v); @@ -121,6 +121,8 @@ public: const attr_map_t &attrs() { return attrs_; } bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); } std::set get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; } + void set_is_kernel(bool new_val) { is_kernel_ = new_val; } + bool get_is_kernel() { return is_kernel_; } void print(std::ostream &os); @@ -134,6 +136,7 @@ private: args_t args_; blocks_t blocks_; attr_map_t attrs_; + bool is_kernel_; }; } diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index e9e0f0f11..c2d427ae8 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -81,6 +81,51 @@ private: value_id_t id_; }; +//===----------------------------------------------------------------------===// +// call_inst classes +//===----------------------------------------------------------------------===// + +class call_inst: public instruction { +private: + std::string repr_impl() const; + call_inst(ir::function* fn, const std::vector& values, const std::string& name, instruction* next); + +public: + static call_inst* create(ir::function* fn, const std::vector& values, const std::string &name = "", instruction *next = nullptr); + ir::function* get_fn() { return fn_; } + + _TRITON_DEFINE_CLONE(call_inst) + _TRITON_DEFINE_ACCEPT(call_inst) + +private: + ir::function* fn_; +}; + +class launch_inst: public instruction { +private: + std::string repr_impl() const { return "launch"; } + launch_inst(ir::function* fn, const std::vector& values, const std::vector& grid, ir::value* num_warps, + const std::string &name = "", 
instruction *next = nullptr); + +public: + static launch_inst* create(ir::function* fn, const std::vector& values, const std::vector& grid, ir::value* num_warps, + const std::string& name = "", instruction* next = nullptr); + + ir::function* get_fn(); + std::vector get_values(); + std::vector get_grid(); + ir::value* get_num_warps(); + + + _TRITON_DEFINE_CLONE(launch_inst) + _TRITON_DEFINE_ACCEPT(launch_inst) + +private: + unsigned val_begin; + unsigned val_end; + unsigned grid_begin; + unsigned grid_end; +}; //===----------------------------------------------------------------------===// // phi_node classes @@ -546,6 +591,44 @@ public: _TRITON_DEFINE_ACCEPT(masked_store_inst) }; +//===----------------------------------------------------------------------===// +// struct classes +//===----------------------------------------------------------------------===// + +// insert_value + +class insert_value_inst: public instruction { +private: + std::string repr_impl() const { return "insertvalue"; } + insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next); + +public: + static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr); + size_t get_idx() { return idx_; } + _TRITON_DEFINE_CLONE(insert_value_inst) + _TRITON_DEFINE_ACCEPT(insert_value_inst) + +private: + size_t idx_; +}; + +// extract_value + +class extract_value_inst: public instruction { +private: + std::string repr_impl() const { return "extractvalue"; } + extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next); + +public: + static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr); + size_t get_idx() { return idx_; } + _TRITON_DEFINE_CLONE(extract_value_inst) + _TRITON_DEFINE_ACCEPT(extract_value_inst) + +private: + size_t idx_; +}; + //===----------------------------------------------------------------------===// // retile_inst classes //===----------------------------------------------------------------------===// diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 30881fd49..f8f033eb7 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -34,79 +34,97 @@ class constant; class global_value; class alloc_const; -/* Module */ - -class module { +class value_constructor { typedef std::pair val_key_t; - friend class function; typedef std::pair md_pair_t; -public: - typedef std::map symbols_map_t; - typedef std::vector functions_list_t; - struct current_iteration_info_t{ - lang::iteration_statement *statement; - basic_block *block; - }; - private: phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); value *try_remove_trivial_phis(ir::phi_node *&phi); value *add_phi_operands(const std::string& name, phi_node *&phi); value *get_value_recursive(const std::string& name, basic_block *block); + +public: + value_constructor(builder &builder); + + void set_value(const std::string& name, basic_block* block, value *x); + void set_value(const std::string& name, value* x); + const std::map& get_values() { return values_; } + void set_values(const std::map& values) { values_ = values; } + value *get_value(const std::string& name, basic_block* block); + value *get_value(const std::string& name); + void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } + // Seal block -- no more predecessors will be added + void seal_block(basic_block *block); + // Metadata + void 
add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } + +private: + ir::builder& builder_; + std::map values_; + std::map types_; + std::set sealed_blocks_; + std::map> incomplete_phis_; + std::map current_phi_; + std::map metadatas_; +}; + +/* Module */ + +class module { + typedef std::pair val_key_t; + friend class function; + +public: + typedef std::map symbols_map_t; + typedef std::vector functions_list_t; + +private: void push_function(function *fn) { functions_.push_back(fn); } public: module(const std::string &name, builder& builder); builder& get_builder(); // Setters - void set_value(const std::string& name, basic_block* block, value *x); - void set_value(const std::string& name, value* x); - void set_const(const std::string& name); void set_continue_fn(std::function fn); // Getters - const std::map& get_values() { return values_; } - const std::map& get_types() { return types_; } - void set_values(const std::map& values) { values_ = values; } - void set_types(const std::map& types) { types_ = types; } - - value *get_value(const std::string& name, basic_block* block); - value *get_value(const std::string& name); - void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } const std::string& get_name(); std::function get_continue_fn(); - // Seal block -- no more predecessors will be added - void seal_block(basic_block *block); // Functions const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } + function *get_function(const std::string& name) { + if(symbols_.find(name) == symbols_.end()) + throw std::runtime_error("function " + name + " is not declared"); + return (function*)symbols_.at(name); + } function *get_or_insert_function(const std::string &name, function_type *ty); + bool has_function(const std::string& name){ + return symbols_.find(name) != symbols_.end(); + } + void remove_function(ir::function* fn){ + functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end()); + } + + void reset_ret_ty(const std::string& name, type* ty); + // Const allocation void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } const std::vector& allocs() { return allocs_; } // Register global void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } const std::map& globals() const { return globals_; } - // Metadata - void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } - + // void print(std::ostream &os); private: std::string name_; builder& builder_; - std::map values_; - std::map types_; - std::set const_; - std::set sealed_blocks_; - std::map> incomplete_phis_; functions_list_t functions_; symbols_map_t symbols_; std::function continue_fn_; - std::map current_phi_; std::vector allocs_; std::map globals_; - std::map metadatas_; }; } diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 47c9b5f85..d7919b4c8 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #ifndef _TRITON_IR_TYPE_H_ #define _TRITON_IR_TYPE_H_ @@ -73,6 +73,8 @@ public: type *get_tile_element_ty() const; unsigned get_pointer_address_space() const; type *get_pointer_element_ty() const; + unsigned get_struct_numel() const { return contained_tys_.size(); } + type *get_struct_type(unsigned int i) const { return contained_tys_[i]; } // primitive predicates bool is_void_ty() const { return id_ == VoidTyID; } @@ -91,6 +93,7 @@ public: bool is_bool_ty() 
const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } + bool is_struct_ty() const { return id_ == StructTyID; } // Composite predicates bool is_int_or_tileint_ty(); @@ -138,10 +141,10 @@ public: switch(id_) { case VoidTyID: return "void"; case FP8TyID: return "fp8"; + case BF16TyID: return "bf16"; case FP16TyID: return "f16"; case FP32TyID: return "f32"; case FP64TyID: return "f64"; - case BF16TyID: return "bf16"; case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; @@ -194,6 +197,16 @@ public: type* get_type_at_index(value *idx) const; }; +class struct_type: public composite_type { +public: + struct_type(const contained_tys_vec_t& tys, bool is_packed); + unsigned get_num_types() const { return contained_tys_.size(); } + static struct_type* get(const contained_tys_vec_t& tys, bool is_packed); + +private: + bool is_packed_; +}; + class block_type: public composite_type { private: block_type(type *ty, const block_shapes_t &shapes); @@ -242,6 +255,7 @@ public: ty_iterator params_end() { return contained_tys_.end(); } type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); } type* get_return_ty() const { return contained_tys_.at(0); } + void reset_ret_ty(type* ty) { contained_tys_[0] = ty;} // factory methods static function_type* get(type *ret_ty, const std::vector& param_tys); }; diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h index 7a132d5e2..fde09121a 100644 --- a/include/triton/ir/value.h +++ b/include/triton/ir/value.h @@ -21,7 +21,7 @@ class visitor; class value { public: - typedef std::set users_t; + typedef std::vector users_t; public: // constructor @@ -30,7 +30,7 @@ public: // uses void add_use(user* arg); users_t::iterator erase_use(user* arg); - const std::set &get_users() { return users_; } + const std::vector &get_users() { return users_; } void replace_all_uses_with(value *target); // name void set_name(const std::string &name); diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 25ce578e3..774f2e172 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -11,6 +11,9 @@ class value; class instruction; +class call_inst; +class launch_inst; + class phi_node; class binary_operator; class getelementptr_inst; @@ -42,6 +45,9 @@ class masked_load_inst; class unmasked_store_inst; class masked_store_inst; +class extract_value_inst; +class insert_value_inst; + class retile_inst; class reshape_inst; class splat_inst; @@ -105,6 +111,8 @@ public: virtual ~visitor() {} virtual void visit_value(ir::value*); + virtual void visit_call_inst(ir::call_inst*) = 0; + virtual void visit_launch_inst(ir::launch_inst*) = 0; virtual void visit_basic_block(basic_block*) = 0; virtual void visit_argument(argument*) = 0; @@ -132,6 +140,9 @@ public: virtual void visit_sin_inst(sin_inst*) = 0; virtual void visit_log_inst(log_inst*) = 0; + virtual void visit_extract_value_inst(extract_value_inst*) = 0; + virtual void visit_insert_value_inst(insert_value_inst*) = 0; + virtual void visit_reshape_inst(reshape_inst*) = 0; virtual void visit_splat_inst(splat_inst*) = 0; virtual void visit_cat_inst(cat_inst*) = 0; diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 5d30a2f45..cec512fec 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -608,6 +608,8 @@ void layouts::run(ir::module &mod) { // create temporaries size_t id = values_.size(); 
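The new call_inst/launch_inst instructions, the insert_value/extract_value pair, and basic_block::split_before introduced above are the IR plumbing for calling one @triton.jit function from another; the inliner pass added later in this patch flattens every call_inst back into its caller before code generation (the generator even raises "call not supported" if one survives). A rough Python-level sketch of the kind of pattern this machinery is meant to enable follows; the kernel and helper names are illustrative, and whether this exact call syntax is accepted at this point in the series is an assumption, not something the patch itself demonstrates.

import torch
import triton
import triton.language as tl

@triton.jit
def plus_one(x):
    # a device helper: the call site becomes a call_inst, which the new
    # inliner pass is expected to flatten into the calling kernel
    return x + 1

@triton.jit
def add_one_kernel(X, Z, BLOCK: tl.constexpr):
    off = tl.arange(0, BLOCK)
    tl.store(Z + off, plus_one(tl.load(X + off)))

x = torch.randn(1024, device='cuda')
z = torch.empty_like(x)
add_one_kernel[(1,)](x, z, BLOCK=1024)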
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) { +// std::cout << "layout: " << std::endl; +// i->print(std::cout); if(auto *red = dynamic_cast(i)) { id++; ir::value *arg = red->get_operand(0); diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 8921d6c84..e2cd6d228 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -13,6 +13,7 @@ #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" #include "triton/codegen/transform/prefetch.h" +#include "triton/codegen/transform/inline.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/print.h" @@ -33,6 +34,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; + codegen::transform::inliner inliner; codegen::analysis::axes axes; codegen::transform::cts cts(cts_use_async); codegen::transform::pipeline pipeline(cts_use_async, num_stages); @@ -48,6 +50,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); // run passes + inliner.run(ir); dce.run(ir); peephole.run(ir); dce.run(ir); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index b36f51d92..0e6ae4539 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -13,6 +13,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" +#include "triton/ir/utils.h" #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicsNVPTX.h" @@ -139,6 +140,14 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ * \brief Convert Triton-IR Type to LLVM-IR Type */ Type *generator::cvt(ir::type *ty) { + // struct + if(ty->is_struct_ty()){ + std::vector tys; + for(size_t i = 0; i < ty->get_struct_numel(); i++) + tys.push_back(cvt(ty->get_struct_type(i))); + return StructType::get(builder_->getContext(), tys, true); + } + // function if(auto* tt = dynamic_cast(ty)){ Type *ret_ty = cvt(tt->get_return_ty()); @@ -266,7 +275,8 @@ void generator::visit_value(ir::value* v) { builder_->SetInsertPoint(&*current->getFirstNonPHI()); // visit user if(auto *usr = dynamic_cast(v)){ - usr->accept(this); + if(!dynamic_cast(usr)) + usr->accept(this); } // revert insert point if(phi && !current->empty() && current->getFirstNonPHI()) @@ -282,6 +292,81 @@ void generator::visit_phi_node(ir::phi_node* x) { vals_[x][idx] = phi(ty, x->get_num_operands()); } +/** + * \brief Code Generation for `call` + */ +void generator::visit_call_inst(ir::call_inst* call) { + throw std::runtime_error("call not supported! 
Triton should be inlining everything."); +} + +void generator::visit_launch_inst(ir::launch_inst *launch) { + ir::function* fn = (ir::function*)launch->get_operand(0); + // forward-declare cudaGetParameterBufferV2 + std::vector get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0), + ArrayType::get(builder_->getInt32Ty(), 3), + ArrayType::get(builder_->getInt32Ty(), 3), + builder_->getInt32Ty()}; + FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false); + Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_); + AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]); + AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]); + ConstantInt* _0 = builder_->getInt32(0); + ConstantInt* _1 = builder_->getInt32(1); + ConstantInt* _2 = builder_->getInt32(2); + // create basic block + BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent()); + BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb); + Value *tid = tgt_->get_local_id(mod_, *builder_, 0); + Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0)); + builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb); + builder_->SetInsertPoint(launch_bb); + + // + builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0})); + builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1})); + builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2})); + Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]); + builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0})); + builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1})); + builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2})); + Function* called_fn = fns_[fn]; + Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]); + Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)}); + // forwrd-declare cudaLaunchDeviceV2 + std::vector launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()}; + FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false); + Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_); + // TODO: add branch + Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()), + builder_->getInt64(0)); + BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb); + builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb); + builder_->SetInsertPoint(launch2_bb); + + unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace(); + unsigned off = 0; + unsigned last_size = 0; + for(ir::value* arg: launch->get_values()){ + Value* curr_arg = vals_[arg][{}]; + Type* curr_arg_ty = curr_arg->getType(); + // handle struct alignment + off += last_size; + unsigned size = curr_arg_ty->isPointerTy() ? 
8 : curr_arg_ty->getPrimitiveSizeInBits() / 8; + off = (off + size - 1) / size * size; + // get pointer to current arg + Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off)); + curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space)); + // store arg + builder_->CreateStore(curr_arg, curr_arg_ptr); + last_size = size; + } + builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)}); + builder_->CreateBr(launch_done_bb); + // done + builder_->SetInsertPoint(launch_done_bb); + +} + /** * \brief Code Generation for `binary_operator` */ @@ -311,6 +396,7 @@ void generator::visit_binary_operator(ir::binary_operator*x) { default: throw std::runtime_error("unreachable switch"); } }; +// x->print(std::cout); for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; @@ -852,6 +938,31 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) { visit_store_inst(x); } +// -- + +void generator::visit_extract_value_inst(ir::extract_value_inst *x) { + auto idxs = idxs_.at(x); + ir::value* agg = x->get_operand(0); + unsigned insert_idx = x->get_idx(); + for(size_t i = 0; i < idxs.size(); i++){ + auto idx = idxs[i]; + vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx}); + } +} + + +void generator::visit_insert_value_inst(ir::insert_value_inst *x){ + auto idxs = idxs_.at(x); + ir::value* agg = x->get_operand(0); + ir::value* val = x->get_operand(1); + unsigned insert_idx = x->get_idx(); + for(size_t i = 0; i < idxs.size(); i++){ + auto idx = idxs[i]; + vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx}); + } +} + +// -- /** * \brief Code Generation for `cat` */ @@ -2686,7 +2797,8 @@ void generator::visit_make_range(ir::make_range* x) { } void generator::visit_undef_value(ir::undef_value *x) { - Type* ty = cvt(x->get_type()->get_scalar_ty()); + ir::type* sca_ty = x->get_type()->get_scalar_ty(); + Type* ty = cvt(sca_ty); for(indices_t idx: idxs_.at(x)) vals_[x][idx] = llvm::UndefValue::get(ty); } @@ -2713,8 +2825,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) { } -void generator::visit_function(ir::function* fn) { - LLVMContext &ctx = builder_->getContext(); +void generator::forward_declare(ir::function* fn){ FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type()); if(!tgt_->is_gpu()){ Type *fn_ret_ty = fn_ty->getReturnType(); @@ -2727,6 +2838,18 @@ void generator::visit_function(ir::function* fn) { fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false); } Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); + fns_[fn] = ret; +} + +void generator::visit_function(ir::function* fn) { + idxs_.clear(); + vals_.clear(); + seen_.clear(); + LLVMContext &ctx = builder_->getContext(); + + Function* ret = fns_[fn]; + + // set attributes for(auto attr_pair: fn->attrs()){ unsigned id = attr_pair.first; @@ -2751,7 +2874,8 @@ void generator::visit_function(ir::function* fn) { for(unsigned i = 0; i < fn->args().size(); i++) vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i); // create blocks - for(ir::basic_block *block: fn->blocks()) { + auto blocks = ir::cfg::reverse_post_order(fn); + for(ir::basic_block *block: blocks) { BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret); bbs_[block] = dst_block; } @@ -2761,7 +2885,7 @@ void generator::visit_function(ir::function* fn) { visit_layout(x.second); } // generate LLVM-IR code - for(ir::basic_block 
*block: fn->blocks()) + for(ir::basic_block *block: blocks) visit_basic_block(block); // finalize finalize_function(fn); @@ -2982,10 +3106,12 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) { } void generator::visit_basic_block(ir::basic_block * block) { + BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); - for(ir::instruction *i: block->get_inst_list()) + for(ir::instruction *i: block->get_inst_list()){ visit_value(i); + } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); } @@ -3168,6 +3294,12 @@ void generator::finalize_phi_node(ir::phi_node *x) { } } +StructType* generator::packed_type(ir::value* i){ + Type* dtype = cvt(i->get_type()->get_tile_element_ty()); + auto* layout = dynamic_cast(layouts_->get(i)); + assert(layout); +} + void generator::visit(ir::module &src, llvm::Module &dst) { mod_ = &dst; ctx_ = &dst.getContext(); @@ -3184,7 +3316,16 @@ void generator::visit(ir::module &src, llvm::Module &dst) { nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); shmem_ = bit_cast(sh_mem_array, ptr_ty); } + // instantiate device functions +// for(ir::function *fn: src.get_function_list()) +// for(ir::basic_block *bb: fn->blocks()) +// for(ir::instruction *i: bb->get_inst_list()) +// if(auto *call = dynamic_cast(i)){ +// std::cout << "call??" << std::endl; +// } // visit functions + for(ir::function *fn: src.get_function_list()) + forward_declare(fn); for(ir::function *fn: src.get_function_list()) visit_function(fn); } diff --git a/lib/codegen/transform/dce.cc b/lib/codegen/transform/dce.cc index c555290f8..7416ff6e8 100644 --- a/lib/codegen/transform/dce.cc +++ b/lib/codegen/transform/dce.cc @@ -3,6 +3,7 @@ #include "triton/ir/basic_block.h" #include "triton/ir/module.h" #include "triton/ir/utils.h" +#include namespace triton { namespace codegen{ @@ -28,6 +29,8 @@ void dce::run(ir::module &mod) { case ir::INST_ATOMIC_CAS: case ir::INST_ATOMIC_RMW: case ir::INST_ATOMIC_EXCH: + case ir::INST_CALL: + case ir::INST_LAUNCH: case ir::INST_BARRIER: { work_list.push_back(i); marked.insert(i); @@ -65,6 +68,7 @@ void dce::run(ir::module &mod) { } } + // delete for(ir::instruction* i: to_delete) i->erase_from_parent(); diff --git a/lib/codegen/transform/inline.cc b/lib/codegen/transform/inline.cc new file mode 100644 index 000000000..fa22e5354 --- /dev/null +++ b/lib/codegen/transform/inline.cc @@ -0,0 +1,127 @@ +#include +#include "triton/codegen/transform/inline.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/utils.h" + +namespace triton{ +namespace codegen{ +namespace transform{ + + +bool fncmp::operator()(ir::function* x, ir::function* y) const { + auto fn_list = x->get_parent()->get_function_list(); + return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y); +}; + +void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, + std::list& callsites){ + ir::basic_block* parent_block = callsite->get_parent(); + ir::function* parent_fn = parent_block->get_parent(); + // the parent block is split into block A and block B: + // - block A (`new_blocks[0]`) is the entry block of the inlined function + // - block B (`exit`) resumes execution of the parent function + ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name()); + ir::basic_block* exit = entry->get_successors()[0]; + std::vector new_blocks = {entry}; + for(size_t i = 1; i < fn->blocks().size(); i++){ + ir::basic_block* block = 
fn->blocks()[i]; + ir::context& ctx = block->get_context(); + const std::string& name = block->get_parent()->get_name() + "_" + block->get_name(); + new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn)); + } + // a phi node holds the return values of the inlined function + if(exit->get_inst_list().empty()) + builder.set_insert_point(exit); + else + builder.set_insert_point(exit->get_first_non_phi()); + ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0); + callsite->replace_all_uses_with(exit_val); + callsite->erase_from_parent(); + // get arguments `fn` is called with + std::vector tgt_args(callsite->op_begin(), callsite->op_end()); + std::vector src_args(fn->args().begin(), fn->args().end()); + // Actually generate the instructions: + // - Remove the branch created by basic_block::split_before + // - Clone all instructions + // - Replace `ret` with incoming nodes to `exit_val` and branches to `exit` + ir::instruction* terminator = new_blocks[0]->get_inst_list().back(); +// new_blocks[0]->get_inst_list().back()->erase_from_parent(); + terminator->erase_from_parent(); + std::map inst_map; + std::map arg_map; + for(size_t k = 0; k < fn->args().size(); k++) + arg_map[fn->args()[k]] = callsite->ops()[k]; + std::vector rpo = ir::cfg::reverse_post_order(fn); + for(size_t i = 0; i < new_blocks.size(); i++){ + ir::basic_block* old_block = fn->blocks()[i]; + ir::basic_block* new_block = new_blocks[i]; + builder.set_insert_point(new_block); + for(ir::instruction* old_inst: old_block->get_inst_list()){ + // clone instruction + ir::instruction* new_inst = old_inst->clone(); + // replace basic block + for(size_t k = 0; k < new_blocks.size(); k++) + new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]); + // replace values + for(size_t k = 0; k < new_inst->get_num_operands(); k++){ + ir::value* op = new_inst->get_operand(k); + if(auto arg_op = dynamic_cast(op)) + new_inst->set_operand(k, arg_map.at(arg_op)); + if(auto inst_op = dynamic_cast(op)) + if(inst_map.find(inst_op) != inst_map.end()) + new_inst->set_operand(k, inst_map.at(inst_op)); + } + // `ret` instruction is a special case: + // instead of returning we need to branch to after the function call + if(ir::return_inst* ret = dynamic_cast(new_inst)){ + if(ir::value* ret_val = ret->get_return_value()) + exit_val->add_incoming(ret_val, new_block); + new_inst = ir::branch_inst::create(exit); + } + inst_map[old_inst] = new_inst; + builder.insert(new_inst); + } + } + if(exit_val->get_num_incoming() == 1) + exit_val->replace_all_uses_with(exit_val->get_incoming_value(0)); + // done -- make sure insert point is properly set to exit block + builder.set_insert_point(exit); +} + +void inliner::run(ir::module &mod) { + + // gather all call sites + while(true){ + std::map counts; + for(ir::function* fn: mod.get_function_list()) + counts[fn] = 0; + + std::list callsites; + for(ir::function* fn: mod.get_function_list()){ + for(ir::basic_block* block: fn->blocks()) + for(ir::instruction* instr: block->get_inst_list()) + if(ir::call_inst* call = dynamic_cast(instr)){ + callsites.push_back(call); + counts[call->get_fn()] += 1; + } + } + + for(auto& count: counts){ + if(!count.first->get_is_kernel() && count.second == 0) + count.first->get_parent()->remove_function(count.first); + } + + if(callsites.empty()) + break; + + for(ir::call_inst* call: callsites) + do_inline(call->get_fn(), call, mod.get_builder(), callsites); + } + + +} + +} +} +} diff --git a/lib/codegen/transform/peephole.cc 
b/lib/codegen/transform/peephole.cc index 0961efc9c..c25a252a8 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -150,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ } bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) { - auto binop = dynamic_cast(value); - if(binop && binop->get_op() == ir::binary_op_t::Mul) { - ir::value *lhs = binop->get_operand(0); - ir::value *rhs = binop->get_operand(1); - ir::constant_int *_1_lhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(lhs)){ - auto *cst = dynamic_cast(splat->get_operand(0)); - if(cst && cst->get_value() == 1) - _1_lhs = cst; - } - ir::constant_int *_1_rhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(rhs)){ - auto *cst = dynamic_cast(splat->get_operand(0)); - if(cst && cst->get_value() == 1) - _1_rhs = cst; - } - if(_1_lhs){ - binop->replace_all_uses_with(rhs); - return true; - } - else if(_1_rhs){ - binop->replace_all_uses_with(lhs); - return true; - } + auto binop = dynamic_cast(value); + if(binop && binop->get_op() == ir::binary_op_t::Mul) { + ir::value *lhs = binop->get_operand(0); + ir::value *rhs = binop->get_operand(1); + ir::constant_int *_1_lhs = nullptr; + if(ir::splat_inst *splat = dynamic_cast(lhs)){ + auto *cst = dynamic_cast(splat->get_operand(0)); + if(cst && cst->get_value() == 1) + _1_lhs = cst; } + ir::constant_int *_1_rhs = nullptr; + if(ir::splat_inst *splat = dynamic_cast(rhs)){ + auto *cst = dynamic_cast(splat->get_operand(0)); + if(cst && cst->get_value() == 1) + _1_rhs = cst; + } + if(_1_lhs){ + binop->replace_all_uses_with(rhs); + return true; + } + else if(_1_rhs){ + binop->replace_all_uses_with(lhs); + return true; + } + } + return false; +} + +bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){ + auto extracted = dynamic_cast(value); + if(!extracted) return false; + size_t extract_idx = extracted->get_idx(); + ir::value* agg = extracted->get_operand(0); + auto insert = dynamic_cast(agg); + while(insert){ + agg = insert->get_operand(0); + ir::value* inserted = insert->get_operand(1); + size_t insert_idx = insert->get_idx(); + insert = dynamic_cast(agg); + if(extract_idx == insert_idx){ + extracted->replace_all_uses_with(inserted); + return true; + } + insert = dynamic_cast(agg); + } + return false; } @@ -291,6 +312,7 @@ void peephole::run(ir::module &mod) { was_modified = was_modified || rewrite_mult(i, builder); // was_modified = was_modified || rewrite_cts_cfs(i, builder); // was_modified = was_modified || rewrite_trans_phi(i, builder); + was_modified = was_modified || rewrite_insert_extract(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index c85ba43a1..0c5c0b292 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) { ir::builder &builder = mod.get_builder(); const int num_stages = num_stages_; std::vector>> preheader_loads; // Used to reorder loads + for(auto info: to_pipeline){ ir::load_inst* load = info.load; ir::phi_node* ptr = info.ptr; diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 9e2aca432..de6f1901b 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -138,6 +138,7 @@ CUDA_DEFINE3(CUresult, 
cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*) // link management +CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **); CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*); CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState); diff --git a/lib/driver/error.cc b/lib/driver/error.cc index f723351c2..fda2b7f33 100755 --- a/lib/driver/error.cc +++ b/lib/driver/error.cc @@ -90,7 +90,7 @@ void check(CUresult err) case CUDA_ERROR_NOT_PERMITTED : throw not_permitted(); case CUDA_ERROR_NOT_SUPPORTED : throw not_supported(); case CUDA_ERROR_UNKNOWN : throw unknown(); - default : throw unknown(); + default : throw std::runtime_error("unimplemented code: " + std::to_string(err)); } } diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 463f45712..92a6b75de 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -174,6 +174,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ init_llvm(); // verify and store llvm llvm::legacy::PassManager pm; +// pm.add(llvm::createPrintModulePass(llvm::outs())); pm.add(llvm::createVerifierPass()); pm.run(*module); // module->print(llvm::outs(), nullptr); @@ -213,6 +214,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ return result; } + std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) { // compile ptx with ptxas char _fsrc[L_tmpnam]; diff --git a/lib/ir/basic_block.cc b/lib/ir/basic_block.cc index 0654156a3..93caef2c3 100644 --- a/lib/ir/basic_block.cc +++ b/lib/ir/basic_block.cc @@ -1,3 +1,5 @@ +#include +#include #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" @@ -9,23 +11,68 @@ namespace ir { class phi_node; -basic_block::basic_block(context &ctx, const std::string &name, function *parent): +basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next): value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) { if(parent_) - parent_->insert_block(this); + parent_->insert_block(this, next); } -basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){ - return new basic_block(ctx, name, parent); +basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){ + return new basic_block(ctx, name, parent, next); } -void basic_block::add_predecessor(basic_block *pred) { - preds_.push_back(pred); - if(pred) - pred->succs_.push_back(this); +void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) { + for(ir::instruction* i: inst_list_){ + auto* curr_phi = dynamic_cast(i); + if(!curr_phi) + break; + curr_phi->replace_uses_of_with(before, after); + } } +void basic_block::append_instruction(ir::instruction* i){ + i->set_parent(this); + inst_list_.push_back(i); +} +basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) { + basic_block* ret = basic_block::create(ctx_, name, parent_, this); + ret->set_name(get_name()); + set_name("after_" + name); + + // splice instruction list + auto loc_it = std::find(inst_list_.begin(), inst_list_.end(), loc); + ret->get_inst_list().splice(ret->get_inst_list().begin(), inst_list_, inst_list_.begin(), loc_it); + for(ir::instruction* 
i: ret->get_inst_list()) + i->set_parent(ret); + // the predecessors of `this` becomes the predecessors of `ret` + for(ir::basic_block* pred: get_predecessors()){ + auto* term = dynamic_cast(pred->get_inst_list().back()); + assert(term); + term->replace_uses_of_with(this, ret); + replace_phi_uses_with(pred, ret); + } + ir::branch_inst* br = branch_inst::create(this); + ret->append_instruction(br); + return ret; +} + +std::vector basic_block::get_predecessors() const { + std::vector ret; + for(ir::user* u: users_) + if(auto term = dynamic_cast(u)) + ret.push_back(term->get_parent()); + return ret; +} + +std::vector basic_block::get_successors() const { + std::vector ret; + for(ir::instruction* i: inst_list_) + for(ir::value* v: i->ops()) + if(auto block = dynamic_cast(v)) + ret.push_back(block); + return ret; +} basic_block::iterator basic_block::get_first_non_phi(){ auto it = begin(); diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index fff73e665..58174aa7a 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -117,13 +117,10 @@ type *builder::get_double_ty() //===----------------------------------------------------------------------===// value* builder::create_br(basic_block *dest){ - dest->add_predecessor(block_); return insert(branch_inst::create(dest)); } value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){ - if_dest->add_predecessor(block_); - else_dest->add_predecessor(block_); return insert(branch_inst::create(cond, if_dest, else_dest)); } @@ -131,6 +128,10 @@ value *builder::create_ret_void() { return insert(return_inst::create(ctx_)); } +value *builder::create_ret(value* val) { + return insert(return_inst::create(ctx_, val)); +} + //===----------------------------------------------------------------------===// // cast instructions //===----------------------------------------------------------------------===// @@ -163,6 +164,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){ return insert(phi_node::create(ty, num_reserved)); } +//===----------------------------------------------------------------------===// +// call instructions +//===----------------------------------------------------------------------===// + +value *builder::create_call(function* fn, const std::vector& args){ + return insert(call_inst::create(fn, args)); +} + +value* builder::create_launch(function* fn, const std::vector& args, const std::vector& grid, value* num_warps){ + return insert(launch_inst::create(fn, args, grid, num_warps)); + +} + //===----------------------------------------------------------------------===// // binary float instructions //===----------------------------------------------------------------------===// @@ -307,6 +321,19 @@ value *builder::create_masked_store(value *ptr, value *val, value *mask){ return insert(masked_store_inst::create(ptr, val, mask)); } +//===----------------------------------------------------------------------===// +// struct instructions +//===----------------------------------------------------------------------===// + + +// Struct instructions +value *builder::create_insert_value(value* val, value *elt, size_t idx){ + return insert(insert_value_inst::create(val, elt, idx)); +} + +value *builder::create_extract_value(value* val, size_t idx) { + return insert(extract_value_inst::create(val, idx)); +} //===----------------------------------------------------------------------===// // block instructions //===----------------------------------------------------------------------===// diff --git 
a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index aabbc4385..fc0252dbf 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -865,6 +865,9 @@ ir::value *dispatch::clock(ir::builder *builder) { } +//===----------------------------------------------------------------------===// +// Control FLow +//===----------------------------------------------------------------------===// // ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ diff --git a/lib/ir/function.cc b/lib/ir/function.cc index 84d52df72..4f3cd5ac6 100644 --- a/lib/ir/function.cc +++ b/lib/ir/function.cc @@ -33,8 +33,10 @@ void argument::accept(visitor *v) { /* function */ function::function(function_type *ty, linkage_types_t linkage, const std::string &name, module *parent) - : global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) { + : global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) { unsigned num_params = fn_ty_->get_num_params(); + if(parent) + parent->push_function(this); // skip if no parameter if(num_params == 0) return; @@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage, type *param_ty = fn_ty_->get_param_ty(i); args_[i] = argument::create(param_ty, "", this, i); } - if(parent) - parent->push_function(this); } /* basic block */ diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index d1f81f136..1bcbfa9ff 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -5,6 +5,7 @@ #include "triton/ir/instructions.h" #include "triton/ir/constant.h" #include "triton/ir/type.h" +#include "triton/ir/function.h" namespace triton{ namespace ir{ @@ -79,6 +80,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n return new phi_node(ty, num_reserved, name, next); } +//===----------------------------------------------------------------------===// +// call_inst classes +//===----------------------------------------------------------------------===// + +std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); } + +call_inst::call_inst(ir::function* fn, const std::vector& values, const std::string& name, instruction* next) + : instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){ + for(size_t i = 0; i < values.size(); i++) + set_operand(i, values.at(i)); +} + +call_inst* call_inst::create(ir::function* fn, const std::vector& values, const std::string &name, instruction *next) { + return new call_inst(fn, values, name, next); +} + + +// launch + +launch_inst::launch_inst(ir::function* fn, const std::vector& values, const std::vector& grid, ir::value* num_warps, const std::string& name, instruction* next) + : instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){ + int k = 0; + if(grid.size() != 3) + throw std::runtime_error("grid must have 3 elements"); + set_operand(k++, fn); + val_begin = k; + for(ir::value* v: values) + set_operand(k++, v); + val_end = k; + grid_begin = k; + for(ir::value* g: grid) + set_operand(k++, g); + grid_end = k; + set_operand(k++, num_warps); +} + + +ir::function* launch_inst::get_fn() { + return (ir::function*)get_operand(0); +} + +std::vector launch_inst::get_values() { + std::vector ret; + for(int i = val_begin; i < val_end; i++) + ret.push_back(get_operand(i)); + return ret; +} + +std::vector launch_inst::get_grid() { + std::vector ret; + for(int i = grid_begin; i < grid_end; i++) + ret.push_back(get_operand(i)); + return ret; +} + +ir::value* 
launch_inst::get_num_warps() { + return get_operand(grid_end); +} + + +launch_inst* launch_inst::create(ir::function *fn, const std::vector &values, const std::vector &grid, ir::value *num_warps, const std::string &name, instruction *next) { + return new launch_inst(fn, values, grid, num_warps, name, next); +} + //===----------------------------------------------------------------------===// // binary_operator classes @@ -324,7 +389,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, // return_inst return_inst::return_inst(context &ctx, value *ret_val, instruction *next) - : terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ + : terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ if(ret_val) set_operand(0, ret_val); } @@ -521,6 +586,36 @@ masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { return new masked_store_inst(ptr, val, mask, name, next); } + +//===----------------------------------------------------------------------===// +// struct classes +//===----------------------------------------------------------------------===// + +// insert value + +insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next) + : instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) { + set_operand(0, val); + set_operand(1, elt); +} + +insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){ + return new insert_value_inst(val, elt, idx, name, next); +} + + +// extract value + +extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next) + : instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) { + set_operand(0, val); +} + +extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){ + return new extract_value_inst(val, idx, name, next); +} + + //===----------------------------------------------------------------------===// // retile_inst classes //===----------------------------------------------------------------------===// @@ -575,6 +670,9 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next); } + + + //===----------------------------------------------------------------------===// // matmul_inst classes //===----------------------------------------------------------------------===// diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 33b39de3a..7df196c8f 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -9,17 +9,12 @@ namespace triton{ namespace ir{ -/* Module */ -module::module(const std::string &name, builder &builder) - : name_(name), builder_(builder) { +/* */ +value_constructor::value_constructor(ir::builder& builder): builder_(builder){ sealed_blocks_.insert(nullptr); } -ir::builder& module::get_builder() { - return builder_; -} - -void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ +void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ values_[val_key_t{name, block}] = value; auto it = metadatas_.find(name); if(auto *x = 
dynamic_cast(value)) @@ -29,23 +24,11 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu // value->set_name(name); } -void module::set_value(const std::string& name, ir::value *value){ +void value_constructor::set_value(const std::string& name, ir::value *value){ return set_value(name, builder_.get_insert_block(), value); } -void module::set_const(const std::string& name){ - const_.insert(name); -} - -void module::set_continue_fn(std::function fn) { - continue_fn_ = fn; -} - -std::function module::get_continue_fn() { - return continue_fn_; -} - -ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ +ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ basic_block::iterator insert = block->get_first_non_phi(); if(insert != block->end()){ builder_.set_insert_point(insert); @@ -56,7 +39,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc return res; } -ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ +ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){ // find non-self references std::set non_self_ref; std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), @@ -69,7 +52,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ assert(same != nullptr); phi->replace_all_uses_with(same); phi->erase_from_parent(); - std::set users = phi->get_users(); + std::vector users = phi->get_users(); for(ir::user* u: users) if(auto *uphi = dynamic_cast(u)) if(uphi != phi) @@ -78,7 +61,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ } -ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){ +ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){ // already initialized if(phi->get_num_operands()) return phi; @@ -90,12 +73,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi) return phi; } -ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) { +ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) { ir::value *result; - bool is_const = const_.find(name) != const_.end(); - auto &preds = block->get_predecessors(); + auto preds = block->get_predecessors(); ir::type *ty = types_.at(name); - if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ + if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){ incomplete_phis_[block][name] = make_phi(ty, 1, block); result = (ir::value*)incomplete_phis_[block][name]; } @@ -117,10 +99,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block return result; } -ir::value *module::get_value(const std::string& name, ir::basic_block *block) { +ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) { ir::basic_block* save_block = builder_.get_insert_block(); ir::basic_block::iterator save_pt = builder_.get_insert_point(); val_key_t key(name, block); +// std::cout << values_.size() << std::endl; +// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl; if(values_.find(key) != values_.end()){ return values_.at(key); } @@ -131,15 +115,11 @@ ir::value *module::get_value(const std::string& name, ir::basic_block *block) { return result; } -ir::value *module::get_value(const 
std::string& name) { +ir::value *value_constructor::get_value(const std::string& name) { return get_value(name, builder_.get_insert_block()); } -const std::string& module::get_name() { - return name_; -} - -void module::seal_block(ir::basic_block *block){ +void value_constructor::seal_block(ir::basic_block *block){ for(auto &x: incomplete_phis_[block]){ add_phi_operands(x.first, x.second); if(get_value(x.first) == x.second) @@ -149,11 +129,40 @@ void module::seal_block(ir::basic_block *block){ incomplete_phis_[block].clear(); } + + +/* Module */ + +module::module(const std::string &name, builder &builder) + : name_(name), builder_(builder) { +} + +void module::reset_ret_ty(const std::string& name, type* ty) { + get_function(name)->get_fn_type()->reset_ret_ty(ty); +} + +ir::builder& module::get_builder() { + return builder_; +} + +void module::set_continue_fn(std::function fn) { + continue_fn_ = fn; +} + +std::function module::get_continue_fn() { + return continue_fn_; +} + +const std::string& module::get_name() { + return name_; +} + /* functions */ function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *&fn = (function*&)symbols_[name]; - if(fn == nullptr) - return fn = function::create(ty, global_value::external, name, this); + if(fn == nullptr){ + fn = function::create(ty, global_value::external, name, this); + } return fn; } diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 7e4e4e5d7..735fad965 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -188,7 +188,26 @@ bool composite_type::index_valid(value *idx) const{ } //===----------------------------------------------------------------------===// -// tile_type class +// struct_type class +//===----------------------------------------------------------------------===// + +struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed) + : composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) { + contained_tys_ = tys; +} + +struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) { + assert(tys.size()); + context_impl* impl = tys[0]->get_context().p_impl.get(); + struct_type *& entry = impl->struct_tys[tys]; + if(!entry) + entry = new struct_type(tys, is_packed); + return entry; +} + + +//===----------------------------------------------------------------------===// +// block_type class //===----------------------------------------------------------------------===// block_type::block_type(type *ty, const block_shapes_t &shapes) diff --git a/lib/ir/value.cc b/lib/ir/value.cc index b970e07d7..251d64479 100644 --- a/lib/ir/value.cc +++ b/lib/ir/value.cc @@ -1,5 +1,6 @@ #include #include +#include #include "triton/ir/value.h" #include "triton/ir/instructions.h" @@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){ } void value::add_use(user *arg) { - users_.insert(arg); + users_.push_back(arg); } value::users_t::iterator value::erase_use(user *arg){ - auto it = users_.find(arg); + auto it = std::find(users_.begin(), users_.end(), arg); if(it == users_.end()) return it; return users_.erase(it); diff --git a/python/setup.py b/python/setup.py index 6a04a4e42..9179baa5b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - self.debug = True + # self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 
'release' diff --git a/python/src/triton.cc b/python/src/triton.cc index 22017ebf5..b97044421 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -659,6 +659,8 @@ void init_triton_ir(py::module &&m) { py::class_(m, "type") .def("is_ptr", &ir::type::is_pointer_ty) .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("get_int_width", &ir::type::get_integer_bitwidth) + .def("is_floating", &ir::type::is_floating_point_ty) .def("is_block", &ir::type::is_block_ty) .def("make_ptr", &ir::pointer_type::get, ret::reference) @@ -695,6 +697,7 @@ void init_triton_ir(py::module &&m) { .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) + .def("is_struct", &ir::type::is_struct_ty) .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) @@ -704,23 +707,37 @@ void init_triton_ir(py::module &&m) { py::class_(m, "pointer_type") .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); - py::class_(m, "function_type"); + py::class_(m, "function_type") + .def_property_readonly("ret_ty", &ir::function_type::get_return_ty) + .def_property_readonly("arg_tys", [](ir::function_type* self){ + return std::vector(self->params_begin(), self->params_end()); + }); + py::class_(m, "integer_type"); + py::class_(m, "block_type") .def_property_readonly("shape", &ir::block_type::get_shapes) .def_property_readonly("numel", &ir::type::get_tile_num_elements); + + py::class_(m, "struct_type") + .def("get", &ir::struct_type::get, ret::reference) + .def_property_readonly("num_types", &ir::struct_type::get_num_types); + + py::class_(m, "value_constructor") + .def(py::init()) + .def("seal_block", &ir::value_constructor::seal_block) + .def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value) + .def("set_type", &ir::value_constructor::set_type) + .def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference) + .def("get_values", &ir::value_constructor::get_values, ret::reference) + .def("set_values", &ir::value_constructor::set_values); py::class_(m, "module") .def(py::init()) + .def("has_function", &ir::module::has_function) + .def("get_function", &ir::module::get_function, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) - .def("seal_block", &ir::module::seal_block) - .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) - .def("set_type", &ir::module::set_type) - .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) - .def("get_values", &ir::module::get_values, ret::reference) - .def("set_values", &ir::module::set_values) - .def("get_types", &ir::module::get_types, ret::reference) - .def("set_types", &ir::module::set_types) + .def("reset_ret_ty", &ir::module::reset_ret_ty) .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; @@ -734,29 +751,45 @@ void init_triton_ir(py::module &&m) { .value("not_implemented", eattr::not_implemented); py::class_(m, "attribute") - .def(py::init()); + .def(py::init()) + .def_property_readonly("value", 
&ir::attribute::get_value); py::class_(m, "function") .def_property_readonly("args", &ir::function::args) .def_property_readonly("attrs", &ir::function::attrs) - .def("add_attr", &ir::function::add_attr); + .def("set_is_kernel", &ir::function::set_is_kernel) + .def("add_attr", &ir::function::add_attr) + .def("has_attr", &ir::function::has_attr) + .def("get_attrs", &ir::function::get_attributes); - py::class_(m, "argument"); + py::class_(m, "argument") + .def_property_readonly("parent", &ir::argument::get_parent, ret::reference) + .def_property_readonly("arg_no", &ir::argument::get_arg_no); py::class_(m, "basic_block") - .def("create", &ir::basic_block::create, ret::reference) + .def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr) .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); + py::class_(m, "bb_iterator"); + py::class_(m, "builder", py::dynamic_attr()) .def(py::init()) // getters .def_property_readonly("context", &ir::builder::get_context, ret::reference) // control flow + .def("call", &ir::builder::create_call, ret::reference) + .def("launch", &ir::builder::create_launch, ret::reference) .def("br", &ir::builder::create_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) + .def("ret", &ir::builder::create_ret, ret::reference) + .def("get_insert_point", &ir::builder::get_insert_point) + .def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) + // struct + .def("insert_value", &ir::builder::create_insert_value, ret::reference) + .def("extract_value", &ir::builder::create_extract_value, ret::reference) // constants .def("get_int1", &ir::builder::get_int1, ret::reference) .def("get_int32", &ir::builder::get_int32, ret::reference) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a49b47585..08e5b721f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -585,7 +585,6 @@ def test_f8_f16_roundtrip(): f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - print(f16.dtype, f8_output.dtype) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non # Parse out the type of the 'VALUE' parameter from the Triton IR. 
triton_ir = pgm.asm['ttir'] - ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) - ir_value_type = None if ir_value_match is None else ir_value_match.group(1) + ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir) + ir_value_type = None if ir_value_match is None else ir_value_match.group(2) assert ir_value_type == value_type @@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda' kernel[(1, )](value, x) else: kernel[(1, )](value, x) +# ------------------------- +# test dynamic parallelism +# ------------------------- + + +@triton.jit +def mult(x, alpha): + tl.store(x + tl.program_id(0), alpha) + + +@triton.jit +def stub(X, alpha, grid_0, grid_1, grid_2): + tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2]) + + +def test_dyn_par(cond=True, device='cuda'): + n_pids = 10 + # pids = torch.arange(n_pids, device=device) + # alpha = 2.0 + # x_ref = pids * alpha + x_tri = torch.full((10,), fill_value=-1., device=device) + # cond = torch.tensor([cond], device=device) + stub[(1,)](x_tri, 3.14, n_pids, 1, 1) + print(x_tri) + # triton.testing.assert_almost_equal(x_ref, x_tri) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index cb705aaa6..e6102366a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -21,6 +21,41 @@ import triton._C.libtriton.triton as _triton from .tools.disasm import extract +def mangle_ty(type): + if type.is_ptr(): + return 'P' + mangle_ty(type.element) + if type.is_int(): + return 'i' + str(type.get_int_width()) + if type.is_fp8(): + return 'fp8' + if type.is_fp16(): + return 'fp16' + if type.is_bf16(): + return 'bf16' + if type.is_fp32(): + return 'fp32' + if type.is_fp64(): + return 'fp64' + if type.is_void(): + return 'V' + if type.is_block(): + elt = mangle_ty(type.scalar) + shape = '_'.join(map(str, type.shape)) + return f'{elt}S{shape}S' + assert False, "Unsupport type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) + mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + class CodeGenerator(ast.NodeVisitor): def get_value(self, name): # search node.id in local scope @@ -36,7 +71,7 @@ class CodeGenerator(ast.NodeVisitor): else: raise ValueError(f'{name} is not defined') if isinstance(ret, triton.language.block): - handle = self.module.get_value(name) + handle = self.value_constructor.get_value(name) return triton.language.block(handle) return ret @@ -44,8 +79,8 @@ class CodeGenerator(ast.NodeVisitor): if isinstance(value, _triton.ir.value): value = triton.language.block(value) if isinstance(value, triton.language.block): - self.module.set_value(name, value.handle) - self.module.set_type(name, value.handle.type) + self.value_constructor.set_value(name, value.handle) + self.value_constructor.set_type(name, value.handle.type) self.lscope[name] = value def is_triton_object(self, value): @@ -58,16 +93,17 @@ class CodeGenerator(ast.NodeVisitor): break return stmts and isinstance(stmt, ast.Return) - def __init__(self, context, prototype, gscope, attributes, constants, kwargs): + def __init__(self, context, prototype, gscope, 
attributes, constants, module=None, is_kernel=False): self.builder = _triton.ir.builder(context) - self.module = _triton.ir.module('', self.builder) + self.value_constructor = _triton.ir.value_constructor(self.builder) + self.module = _triton.ir.module('', self.builder) if module is None else module self.prototype = prototype self.gscope = gscope self.lscope = dict() self.attributes = attributes self.constants = constants - self.kwargs = kwargs self.last_node = None + self.is_kernel = is_kernel self.builtins = { 'range': range, 'min': triton.language.minimum, @@ -92,9 +128,17 @@ class CodeGenerator(ast.NodeVisitor): ret = self.visit(node.value) if ret is None: return self.builder.ret_void() - return ret + if isinstance(ret, _triton.ir.value): + ret = self.builder.ret(ret) + return ret + if isinstance(ret, triton.language.block): + ret = ret.handle + if isinstance(ret, triton.language.constexpr): + ret = triton.language.core._to_ir(ret, self.builder) + # TODO: should return tl.block + return self.builder.ret(ret) - def visit_FunctionDef(self, node, inline=False, arg_values=None): + def visit_FunctionDef(self, node): arg_names, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): @@ -107,45 +151,44 @@ class CodeGenerator(ast.NodeVisitor): else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) - # store keyword arguments in local scope - self.lscope[kwarg_names] = self.kwargs # initialize function - if inline: - pass - else: - fn = self.module.get_or_insert_function(node.name, self.prototype) - arg_values = [] - idx = 0 - for i, arg_name in enumerate(arg_names): - if i in self.constants: - cst = self.constants[i] - if not isinstance(cst, triton.language.constexpr): - cst = triton.language.constexpr(self.constants[i]) - arg_values.append(cst) - else: - if i in self.attributes: - is_ptr = fn.args[idx].type.is_ptr() - attr = 'aligned' if is_ptr else 'multiple_of' - attr = getattr(_triton.ir.attribute_kind, attr) - attr = _triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(idx + 1, attr) - fn.args[idx].name = arg_name - arg_values.append(fn.args[idx]) - idx += 1 + fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants) + fn = self.module.get_or_insert_function(fn_name, self.prototype) + fn.set_is_kernel(self.is_kernel) + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not isinstance(cst, triton.language.constexpr): + cst = triton.language.constexpr(self.constants[i]) + arg_values.append(cst) + else: + if i in self.attributes: + is_ptr = fn.args[idx].type.is_ptr() + attr = 'aligned' if is_ptr else 'multiple_of' + attr = getattr(_triton.ir.attribute_kind, attr) + attr = _triton.ir.attribute(attr, self.attributes[i]) + fn.add_attr(idx + 1, attr) + fn.args[idx].name = arg_name + arg_values.append(fn.args[idx]) + idx += 1 + insert_pt = self.builder.get_insert_block() + entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) + self.builder.set_insert_block(entry) + self.value_constructor.seal_block(entry) for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) - if inline: - self.visit_compound_statement(node.body) - return self.last_ret - else: - entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) - self.module.seal_block(entry) - self.builder.set_insert_block(entry) - # visit function body - 
self.visit_compound_statement(node.body) - # finalize function + # visit function body + has_ret = self.visit_compound_statement(node.body) + # finalize + if not has_ret: self.builder.ret_void() + else: + self.module.reset_ret_ty(fn_name, self.last_ret.type) + # self.module.reset_ret_type(node.name) + self.builder.set_insert_block(insert_pt) def visit_arguments(self, node): arg_names = [] @@ -186,6 +229,12 @@ class CodeGenerator(ast.NodeVisitor): names = [names] if not isinstance(values, tuple): values = [values] + if isinstance(values[0], _triton.ir.value): + struct = values[0] + ty = struct.type + if ty.is_struct(): + values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)] + assert len(values) == len(names) for name, value in zip(names, values): # by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): @@ -215,6 +264,17 @@ class CodeGenerator(ast.NodeVisitor): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] + mode = type(args[0]) + # tuple of values -- create a struct + if len(args) > 1 and mode == triton.language.block\ + and all([type(arg) == mode for arg in args]): + args = [arg.handle for arg in args] + tys = [arg.type for arg in args] + struct_ty = _triton.ir.struct_type.get(tys, True) + ret = _triton.ir.undef.get(struct_ty) + for i, arg in enumerate(args): + ret = self.builder.insert_value(ret, arg, i) + return ret return tuple(args) def visit_BinOp(self, node): @@ -254,9 +314,9 @@ class CodeGenerator(ast.NodeVisitor): then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self.module.seal_block(then_bb) + self.value_constructor.seal_block(then_bb) if else_bb: - self.module.seal_block(else_bb) + self.value_constructor.seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -271,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
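            # visit_compound_statement() reports whether the body ended in a Return;
            # only branch to the merge block when it did not, since in that case the
            # current block already ends with a terminator.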
if not is_terminator: self.builder.br(endif_bb) - self.module.seal_block(endif_bb) + self.value_constructor.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -350,9 +410,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self.value_constructor.seal_block(stop_bb) + self.value_constructor.seal_block(loop_bb) + self.value_constructor.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -421,9 +481,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self.value_constructor.seal_block(stop_bb) + self.value_constructor.seal_block(loop_bb) + self.value_constructor.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -449,15 +509,62 @@ class CodeGenerator(ast.NodeVisitor): for keyword in node.keywords: kws.update(self.visit(keyword)) args = [self.visit(arg) for arg in node.args] + if isinstance(fn, JITFunction): - return fn(*args, generator=self, **kws) + from inspect import getcallargs + args = getcallargs(fn.fn, *args, **kws) + args = [args[name] for name in fn.arg_names] + args = [arg if isinstance(arg, triton.language.block) + else triton.language.constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in arg_vals] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + ret_type = _triton.ir.type.get_void(self.builder.context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module) + generator.visit(fn.parse()) + symbol = self.module.get_function(fn_name) + ret = self.builder.call(symbol, arg_vals) + if not ret.type.is_void() and not ret.type.is_struct(): + ret = triton.language.block(ret) + return ret + # built-in function if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: - return fn(*args, _builder=self.builder, **kws) + ret = fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args] - return fn(*args, **kws) + ret = fn(*args, **kws) + # special case: dynamic parallelism + # in this case the core primitive returns a proxy + # if isinstance(ret, triton.language.core.LaunchProxy): + # ret_type = _triton.ir.type.get_void(self.builder.context) + # arg_tys = [x.type for x in ret.args] + # prototype = _triton.ir.type.make_function(ret_type, arg_tys) + # gscope = sys.modules[ret.fn.fn.__module__].__dict__ + # constants = ret.constants + # fn_name = mangle_fn(ret.fn.__name__, arg_tys, ret.constants) + # # 
TODO: clean-up attributes handling in function + # if not self.module.has_function(fn_name): + # attributes = {i: list(arg.parent.get_attrs(arg))[0].value for i, arg in enumerate(ret.args) \ + # if isinstance(arg, _triton.ir.argument) and arg.parent.has_attr(i + 1) } + # generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, is_kernel=True) + # generator.visit(ret.fn.parse()) + # symbol = self.module.get_function(fn_name) + # # TODO: should ret.args not include any constants ? + # ret = self.builder.launch(symbol, ret.args, ret.grid, ret.num_warps) + return ret + # return fn(*args, **kws) def visit_Constant(self, node): return triton.language.constexpr(node.value) @@ -669,6 +776,7 @@ class Kernel: def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes attributes = dict() for i, arg in enumerate(wargs): @@ -881,7 +989,7 @@ class JITFunction: cache_hook = None - def __init__(self, fn, version=None, do_not_specialize=None): + def __init__(self, fn, version=None, inline=True, do_not_specialize=None): # information of wrapped function self.fn = fn self.module = fn.__module__ @@ -890,6 +998,7 @@ class JITFunction: self.arg_defaults = [v.default for v in signature.parameters.values()] self.version = version + self.inline = inline self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize @@ -904,6 +1013,8 @@ class JITFunction: # annotations self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} self.__annotations__ = fn.__annotations__ + # constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] # forward docs self.__doc__ = fn.__doc__ self.__name__ = fn.__name__ @@ -930,31 +1041,8 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - def __call__(self, *args, generator: CodeGenerator, **kwargs): - try: - from inspect import getcallargs - arg_values = getcallargs(self.fn, *args, **kwargs) - arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.block) - else triton.language.constexpr(arg) for arg in arg_values] - - gscope = generator.gscope.copy() - lscope = generator.lscope.copy() - values = generator.module.get_values().copy() - types = generator.module.get_types().copy() - generator.gscope = sys.modules[self.fn.__module__].__dict__ - generator.lscope = dict() - ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) - generator.gscope = gscope - generator.lscope = lscope - generator.module.set_values(values) - generator.module.set_types(types) - return ret - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.") # - when `.src` attribute is set, cache path needs # to be reinitialized @@ -1039,7 +1127,7 @@ class JITFunction: # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + generator = 
CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) try: generator.visit(self.parse()) except Exception as e: @@ -1199,9 +1287,21 @@ def jit(*args, **kwargs): return JITFunction(fn, **kwargs) return decorator +###### + +# class ForwardDeclaration: + +# def __init__(self, name, ret_ty, arg_tys) -> None: +# self.name = name +# self.ret_ty = ret_ty +# self.arg_tys = arg_tys + +# def forward_declare(name, ret_ty, arg_tys): +# return ForwardDeclaration(name, ret_ty, arg_tys) ###### + def cdiv(x, y): return (x + y - 1) // y diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 0312d8146..cad4edfe4 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -888,7 +888,7 @@ def sigmoid(x): @triton.jit @_add_math_1arg_docstr("softmax") -def softmax(x, ieee_rounding=False): +def softmax(x, ieee_rounding: constexpr = False): z = x - triton.language.max(x, 0) num = triton.language.exp(z) den = triton.language.sum(num, 0) @@ -942,3 +942,26 @@ def swizzle2d(i, j, size_i, size_j, size_g): @triton.jit def zeros_like(input): return zeros(input.shape, input.dtype) +# ----------------------- +# Dynamic Parallelism +# ----------------------- + + +class LaunchProxy: + + def __init__(self, fn, args, constants, grid, num_warps) -> None: + self.args = args + self.grid = grid + self.constants = constants + self.num_warps = num_warps + self.fn = fn + + +@builtin +def launch(fn, args, grid, num_warps=None, _builder=None): + constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)} + args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)] + grid = [_to_ir(x, builder=_builder) for x in grid] + if num_warps is None: + num_warps = _to_ir(4, builder=_builder) + return LaunchProxy(fn, args, constants, grid, num_warps) From 9f08ecd68494c007ca198b2bde6464052959cf06 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 6 Apr 2022 16:13:53 -0700 Subject: [PATCH 086/215] [FRONTEND] Semantic analysis refactor (#491) Moved dispatch.cc to semantic.py (@ptillet) Integer signedness analysis was moved from C++ to python (@daadaada) Cleaner frontend types (@daadaada) Moved SSA construction to a separate object (@ptillet) Co-authored-by: Yan Da --- include/triton/ir/builder.h | 34 +- include/triton/ir/context_impl.h | 1 - include/triton/ir/dispatch.h | 117 --- include/triton/ir/module.h | 23 +- include/triton/ir/type.h | 20 +- lib/ir/builder.cc | 63 +- lib/ir/context.cc | 18 +- lib/ir/dispatch.cc | 895 ----------------- lib/ir/instructions.cc | 4 +- lib/ir/module.cc | 144 --- lib/ir/type.cc | 14 - python/src/triton.cc | 408 +++++--- python/test/regression/test_performance.py | 2 +- python/test/unit/language/test_core.py | 38 +- python/test/unit/runtime/test_cache.py | 28 + python/triton/__init__.py | 3 +- python/triton/code_gen.py | 403 +++++--- python/triton/language/core.py | 652 ++++++++---- python/triton/language/semantic.py | 1052 ++++++++++++++++++++ 19 files changed, 2174 insertions(+), 1745 deletions(-) delete mode 100644 include/triton/ir/dispatch.h delete mode 100644 lib/ir/dispatch.cc create mode 100644 python/triton/language/semantic.py diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index ff8447124..0cb622679 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -41,10 +41,8 @@ public: iterator get_insert_point() { return insert_point_;} // Constants value *get_int1(bool val); - value *get_int32(int32_t val); - value 
*get_int64(int64_t val); - value *get_uint32(uint32_t val); - value *get_uint64(uint64_t val); + value *get_int32(uint32_t val); + value *get_int64(uint64_t val); value *get_float16(float val); value *get_float32(float val); value *get_range(int32_t lo, int32_t hi); @@ -55,11 +53,9 @@ public: type *get_int16_ty(); type *get_int32_ty(); type *get_int64_ty(); - type *get_uint8_ty(); - type *get_uint16_ty(); - type *get_uint32_ty(); - type *get_uint64_ty(); + type *get_fp8_ty(); type *get_half_ty(); + type *get_bf16_ty(); type *get_float_ty(); type *get_double_ty(); // Insert @@ -78,7 +74,9 @@ public: value* create_ret_void(); value* create_ret(value *ret); // Cast instructions + value* create_bitcast(value *src, type *dest_ty); value *create_cast(cast_op_t op, value *v, type *dst_ty); + value* create_int_to_ptr(value *src, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty); value* create_si_to_fp(value *src, type *dst_ty); value* create_ui_to_fp(value *src, type *dst_ty); @@ -100,11 +98,11 @@ public: value *create_frem(value *lhs, value *rhs); value *create_fadd(value *lhs, value *rhs); value *create_fsub(value *lhs, value *rhs); - value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sdiv(value *lhs, value *rhs); value *create_udiv(value *lhs, value *rhs); value *create_srem(value *lhs, value *rhs); value *create_urem(value *lhs, value *rhs); + value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); @@ -155,11 +153,25 @@ public: value *create_reshape(value *arg, const type::block_shapes_t &shapes); value *create_cat(value *lhs, value *rhs); value *create_broadcast(value *arg, const type::block_shapes_t &shapes); + // Atomic instruction + value *create_atomic_cas(value *ptr, value *cmp, value *val); + value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk); + value *create_atomic_max(value *ptr, value *val, value *msk); + value *create_atomic_umax(value *ptr, value *val, value *msk); + value *create_atomic_min(value *ptr, value *val, value *msk); + value *create_atomic_umin(value *ptr, value *val, value *msk); + value *create_atomic_fadd(value *ptr, value *val, value *msk); + value *create_atomic_add(value *ptr, value *val, value *msk); + value *create_atomic_and(value *ptr, value *val, value *msk); + value *create_atomic_or(value *ptr, value *val, value *msk); + value *create_atomic_xor(value *ptr, value *val, value *msk); + value *create_atomic_xchg(value *ptr, value *val, value *msk); + // Utilities + value *create_clock(); + value *create_globaltimer(); // Built-in instruction value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); - value *create_atomic_cas(value *ptr, value *cmp, value *val); - value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk); value *create_exp(value* arg); value *create_cos(value* arg); value *create_sin(value* arg); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 619ae4c87..f2d956a30 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -26,7 +26,6 @@ public: type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; // integer types integer_type int1_ty, 
int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; - integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types std::map, std::unique_ptr> ptr_tys; // Block types diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h deleted file mode 100644 index c7f23779c..000000000 --- a/include/triton/ir/dispatch.h +++ /dev/null @@ -1,117 +0,0 @@ -#pragma once - -#ifndef _TRITON_IR_DISPATCH_H_ -#define _TRITON_IR_DISPATCH_H_ - -#include "triton/ir/builder.h" -#include - -namespace triton{ -namespace ir{ - - -/*---------------------------------------------- - higher level functions that follow the likely - semantics of most expected frontends - ----------------------------------------------*/ - -struct semantic_error: public std::runtime_error { - semantic_error(const std::string& msg): - std::runtime_error(msg) { } -}; - -struct dispatch{ - typedef ir::type::block_shapes_t shape_t; - - - // programming model - static ir::value *program_id(int axis, ir::builder *builder); - static ir::value *num_programs(int axis, ir::builder *builder); - - // binary operators - static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder); - static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder); - - // unary operators - static ir::value *plus(ir::value *input, ir::builder *builder); - static ir::value *minus(ir::value *input, ir::builder *builder); - static ir::value *invert(ir::value *input, ir::builder *builder); - - // comparison operators - static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder); - - // block creation - static ir::value* arange(int start, int end, ir::builder *builder); - static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder); - - - // casting ops - static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder); - static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder); - static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder); - static std::tuple broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder); - static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder); - static 
ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); - - // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, - const std::string& eviction_policy, int is_volatile, ir::builder *builder); - static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); - static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); - static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - - // linear algebra - static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder); - - // indexing - static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); - - // reduction - static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder); - - // math - static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); - static ir::value *exp(ir::value *x, ir::builder *builder); - static ir::value *log(ir::value *x, ir::builder *builder); - static ir::value *cos(ir::value *x, ir::builder *builder); - static ir::value *sin(ir::value *x, ir::builder *builder); - static ir::value *sqrt(ir::value *x, ir::builder *builder); - - // utilities - static ir::value *globaltimer(ir::builder *builder); - static ir::value *clock(ir::builder *builder); - - // internal (debug/optimization) - static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); - static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); - static ir::value *debug_barrier(ir::builder *builder); -}; - -} -} - -#endif diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index f8f033eb7..aa279af98 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -36,7 +36,6 @@ class alloc_const; class value_constructor { typedef std::pair val_key_t; - typedef std::pair md_pair_t; private: phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); @@ -57,7 +56,6 @@ public: // Seal block -- no more predecessors will be added void seal_block(basic_block *block); // Metadata - void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } private: ir::builder& builder_; @@ -66,13 +64,13 @@ private: std::set sealed_blocks_; std::map> incomplete_phis_; std::map current_phi_; - std::map metadatas_; }; /* Module */ class module { typedef std::pair val_key_t; + typedef std::pair md_pair_t; friend class function; public: @@ -83,13 +81,10 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - 
module(const std::string &name, builder& builder); - builder& get_builder(); - // Setters - void set_continue_fn(std::function fn); - // Getters - const std::string& get_name(); - std::function get_continue_fn(); + module(const std::string &name, builder &builder): name_(name), builder_(builder) {} + builder &get_builder() { return builder_; }; + const std::string& get_name() { return name_; }; + // Functions const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } @@ -114,17 +109,19 @@ public: // Register global void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } const std::map& globals() const { return globals_; } - // + // Metadata void print(std::ostream &os); + void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } + const std::map &get_metadatas() const { return metadatas_; } private: std::string name_; - builder& builder_; + builder &builder_; functions_list_t functions_; symbols_map_t symbols_; - std::function continue_fn_; std::vector allocs_; std::map globals_; + std::map metadatas_; }; } diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index d7919b4c8..16a81cb5f 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -16,8 +16,6 @@ class value; class integer_type; class constant_int; -enum class signedness { SIGNED, UNSIGNED }; - /* Type */ class type { public: @@ -61,8 +59,6 @@ public: // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bitwidth() const; - signedness get_integer_signedness() const; - bool is_integer_signed() const; unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; @@ -87,9 +83,6 @@ public: bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; } - bool is_integer_ty(unsigned bitwidth, signedness sn) { - return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn; - } bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } @@ -118,10 +111,6 @@ public: static integer_type *get_int32_ty(context &ctx); static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); - static integer_type *get_uint8_ty(context &ctx); - static integer_type *get_uint16_ty(context &ctx); - static integer_type *get_uint32_ty(context &ctx); - static integer_type *get_uint64_ty(context &ctx); // repr std::string tile_repr() const { @@ -148,7 +137,7 @@ public: case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; - case IntegerTyID: return (is_integer_signed() ? 
"i" : "u") + std::to_string(get_integer_bitwidth()); + case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth()); case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; @@ -171,21 +160,18 @@ class integer_type: public type { private: // constructors - integer_type(context &ctx, unsigned bitwidth, signedness sn) - : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ } + integer_type(context &ctx, unsigned bitwidth) + : type(ctx, IntegerTyID), bitwidth_(bitwidth) {} public: // accessors unsigned get_bitwidth() const { return bitwidth_; } - signedness get_signedness() const { return signedness_; } - // factory methods static integer_type* get(context &ctx, unsigned width); private: unsigned bitwidth_; - signedness signedness_; }; class composite_type: public type{ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 58174aa7a..d79e5d9d1 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -48,18 +49,12 @@ void builder::set_insert_point(basic_block *block){ value *builder::get_int1(bool val) { return constant_int::get(type::get_int1_ty(ctx_), val); } -value *builder::get_int32(int32_t val) +value *builder::get_int32(uint32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} -value *builder::get_uint32(uint32_t val) -{ return constant_int::get(type::get_uint32_ty(ctx_), val);} - -value *builder::get_int64(int64_t val) +value *builder::get_int64(uint64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} -value *builder::get_uint64(uint64_t val) -{ return constant_int::get(type::get_uint64_ty(ctx_), val);} - value *builder::get_float16(float val) { return constant_fp::get(type::get_fp16_ty(ctx_), val); } @@ -90,21 +85,15 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } -type *builder::get_uint8_ty() -{ return type::get_uint8_ty(ctx_); } - -type *builder::get_uint16_ty() -{ return type::get_uint16_ty(ctx_); } - -type *builder::get_uint32_ty() -{ return type::get_uint32_ty(ctx_); } - -type *builder::get_uint64_ty() -{ return type::get_uint64_ty(ctx_); } +type *builder::get_fp8_ty() +{ return type::get_fp8_ty(ctx_); } type *builder::get_half_ty() { return type::get_fp16_ty(ctx_); } +type *builder::get_bf16_ty() +{ return type::get_bf16_ty(ctx_); } + type *builder::get_float_ty() { return type::get_fp32_ty(ctx_); } @@ -140,6 +129,8 @@ value *builder::create_ret(value* val) { return create_cast(OPCODE, src, dst_ty);\ } +DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast) +DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr) DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) @@ -358,6 +349,37 @@ value *builder::create_downcast(value *arg) { return insert(downcast_inst::create(arg)); } +// + +value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ + return insert(atomic_rmw_inst::create(op, ptr, val, msk)); +} + +#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\ + value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\ + return create_atomic_rmw(OPCODE, ptr, val, mask);\ + } + +DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max) +DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax) +DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min) +DEFINE_ATOMIC_RMW_INSTR(atomic_umin, 
ir::atomic_rmw_op_t::UMin) +DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd) +DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) +DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And) +DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or) +DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor) +DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg) + +// Utilities +value *builder::create_clock() { + return insert(clock_inst::create(ctx_)); +} + +value *builder::create_globaltimer() { + return insert(globaltimer_inst::create(ctx_)); +} + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// @@ -374,9 +396,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ return insert(atomic_cas_inst::create(ptr, cmp, val)); } -value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ - return insert(atomic_rmw_inst::create(op, ptr, val, msk)); -} value *builder::create_exp(value *arg){ return insert(exp_inst::create(arg)); diff --git a/lib/ir/context.cc b/lib/ir/context.cc index 90b109b9b..0fc65ddc2 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -19,18 +19,12 @@ context_impl::context_impl(context &ctx) fp32_ty(ctx, type::FP32TyID), fp64_ty(ctx, type::FP64TyID), // integers - int1_ty(ctx, 1, signedness::SIGNED), - int8_ty(ctx, 8, signedness::SIGNED), - int16_ty(ctx, 16, signedness::SIGNED), - int32_ty(ctx, 32, signedness::SIGNED), - int64_ty(ctx, 64, signedness::SIGNED), - int128_ty(ctx, 128, signedness::SIGNED), - uint8_ty(ctx, 8, signedness::UNSIGNED), - uint16_ty(ctx, 16, signedness::UNSIGNED), - uint32_ty(ctx, 32, signedness::UNSIGNED), - uint64_ty(ctx, 64, signedness::UNSIGNED){ - -} + int1_ty(ctx, 1), + int8_ty(ctx, 8), + int16_ty(ctx, 16), + int32_ty(ctx, 32), + int64_ty(ctx, 64), + int128_ty(ctx, 128) {} //===----------------------------------------------------------------------===// // context diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc deleted file mode 100644 index fc0252dbf..000000000 --- a/lib/ir/dispatch.cc +++ /dev/null @@ -1,895 +0,0 @@ -#include "triton/ir/dispatch.h" - -namespace triton { -namespace ir { - - -[[ noreturn ]] void throw_unreachable(std::string key) { - throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. " - "This is likely a bug on our side."); -} - -//===----------------------------------------------------------------------===// -// Programming Model -//===----------------------------------------------------------------------===// - -ir::value *dispatch::program_id(int axis, ir::builder *builder) { - return builder->create_get_program_id(axis); -} - -ir::value *dispatch::num_programs(int axis, ir::builder *builder) { - return builder->create_get_num_programs(axis); -} - -//===----------------------------------------------------------------------===// -// Implicit Casting Utilities -//===----------------------------------------------------------------------===// - -ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ - int a_rank = a_ty->get_integer_bitwidth(); - int b_rank = b_ty->get_integer_bitwidth(); - auto a_sn = a_ty->get_integer_signedness(); - auto b_sn = b_ty->get_integer_signedness(); - // Rules for signedness taken from "Usual arithmetic conversions" on - // https://en.cppreference.com/w/c/language/conversion. - if (a_sn == b_sn) { - return a_rank > b_rank ? 
a_ty : b_ty; - } else if (a_sn == signedness::UNSIGNED) { - return a_rank >= b_rank ? a_ty : b_ty; - } else if (b_sn == signedness::UNSIGNED) { - return b_rank >= a_rank ? b_ty : a_ty; - } else { - throw_unreachable("integer_promote"); - } -} - -enum class DivOrMod { NO, YES }; - -ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) { - context &ctx = a_ty->get_context(); - // 1) if one operand is double, the other is implicitly - // converted to double - if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) - return type::get_fp64_ty(ctx); - // 2) if one operand is float, the other is implicitly - // converted to float - if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) - return type::get_fp32_ty(ctx); - // 3 ) if one operand is half, the other is implicitly converted to half - // unless we're doing / or %, which do not exist natively in PTX for fp16. - if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) { - if (div_or_mod == DivOrMod::YES) { - return type::get_fp32_ty(ctx); - } else { - return type::get_fp16_ty(ctx); - } - } - if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) - throw_unreachable("computation_type"); - // 4 ) both operands are integer and undergo - // integer promotion - if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) { - throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); - } - return integer_promote(a_ty, b_ty); -} - -//===----------------------------------------------------------------------===// -// Binary Operators -//===----------------------------------------------------------------------===// - -void throw_incompatible_types(ir::type* type_a, ir::type* type_b) { - throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr()); -} - -void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){ - - if(type_a->is_pointer_ty()){ - if(!allow_ptr_a) - throw_incompatible_types(type_a, type_b); - // T* + U* with T != U - if(type_b->is_pointer_ty() && (type_a != type_b)) - throw_incompatible_types(type_a, type_b); - // T* + float - if(type_b->is_floating_point_ty()) - throw_incompatible_types(type_a, type_b); - } -} - -void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, - bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, - bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) { - // implicit broadcasting - std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); - // implicit typecasting - ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty(); - ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); - check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); - check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); - if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) { - ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod); - lhs = dispatch::cast(lhs, ret_sca_ty, builder); - rhs = dispatch::cast(rhs, ret_sca_ty, builder); - } -} - -ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, true, true); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // offset + ptr - // ptr + offset - if(other_scalar_ty->is_pointer_ty() && 
!input_scalar_ty->is_pointer_ty()) - std::swap(input, other); - if (input_scalar_ty->is_pointer_ty()) - return builder->create_gep(input, {other}); - // float + float - else if (input_scalar_ty->is_floating_point_ty()) - return builder->create_fadd(input, other); - // int + int - else if (input_scalar_ty->is_integer_ty()) - return builder->create_add(input, other); - throw_unreachable("add"); -} - -ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, true, false); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // ptr - offset - if (scalar_ty->is_pointer_ty()) - return builder->create_gep(input, {dispatch::minus(other, builder)}); - // float + float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fsub(input, other); - // int + int - else if (scalar_ty->is_integer_ty()) - return builder->create_sub(input, other); - throw_unreachable("sub"); -} - -ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float * float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fmul(input, other); - // int * int - else if (scalar_ty->is_integer_ty()) - return builder->create_mul(input, other); - throw_unreachable("mul"); -} - -ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // float / int - if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty()) - other = cast(other, input_scalar_ty, builder); - // int / float - else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty()) - input = cast(input, other_scalar_ty, builder); - // int / int (cast to float32) - else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ - input = cast(input, builder->get_float_ty(), builder); - other = cast(other, builder->get_float_ty(), builder); - } - // float / float (cast to highest exponent type) - else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){ - if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width()) - other = cast(other, input_scalar_ty, builder); - else - input = cast(input, other_scalar_ty, builder); - } - // unreachable - else - throw_unreachable("div"); - return builder->create_fdiv(input, other); -} - -ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ - ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); - input = dispatch::cast(input, ret_ty, builder); - other = dispatch::cast(other, ret_ty, builder); - if (ret_ty->is_integer_signed()) { - return builder->create_sdiv(input, other); - } else { - return builder->create_udiv(input, other); - } - } - throw_unreachable("floordiv"); -} - -ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){ - ir::type 
*input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty()) - throw semantic_error("both operands of fdiv must have floating point scalar type"); - binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES); - ir::value* ret = builder->create_fdiv(input, other); - if(ir::binary_operator* binop = dynamic_cast(ret)) - binop->set_fdiv_ieee_rounding(ieee_rounding->get_value()); - return ret; -} - -ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // float % int - if (scalar_ty->is_floating_point_ty()) - return builder->create_frem(input, other); - // int % int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) { - throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); - } - if (scalar_ty->is_integer_signed()) { - return builder->create_srem(input, other); - } else { - return builder->create_urem(input, other); - } - } - throw_unreachable("mod"); -} - - -void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, false); - ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); - ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); - if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) - throw_incompatible_types(input_sca_ty, other_sca_ty); - ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty); - if (ret_sca_ty != input_sca_ty) - input = dispatch::cast(input, ret_sca_ty, builder); - if (ret_sca_ty != other_sca_ty) - other = dispatch::cast(other, ret_sca_ty, builder); -} - -ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_and(input, other); -} - -ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_or(input, other); -} - - -ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_xor(input, other); -} - - -ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_lshr(input, other); -} - - -ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_shl(input, other); -} - -//===----------------------------------------------------------------------===// -// Unary Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::plus(ir::value *input, ir::builder *) { - return input; -} - -ir::value *dispatch::minus(ir::value *input, ir::builder *builder) { - ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); - 
if(input_sca_ty->is_pointer_ty()) - throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")"); - ir::value *_0 = ir::constant::get_null_value(input_sca_ty); - return dispatch::sub(_0, input, builder); -} - -ir::value *dispatch::invert(ir::value *input, ir::builder *builder) { - ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); - if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty()) - throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")"); - ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty); - return dispatch::xor_(input, _1, builder); -} - - -//===----------------------------------------------------------------------===// -// Comparison Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float > float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGT(input, other); - // int > int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSGT(input, other); - } else { - return builder->create_icmpUGT(input, other); - } - } - throw_unreachable("greater_than"); -} - -ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float >= float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGE(input, other); - // int >= int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSGE(input, other); - } else { - return builder->create_icmpUGE(input, other); - } - } - throw_unreachable("greater_equal"); -} - -ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLT(input, other); - // int < int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSLT(input, other); - } else { - return builder->create_icmpULT(input, other); - } - } - throw_unreachable("less_than"); -} - -ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLE(input, other); - // int < int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSLE(input, other); - } else { - return builder->create_icmpULE(input, other); - } - } - throw_unreachable("less_equal"); -} - -ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOEQ(input, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpEQ(input, other); - throw_unreachable("equal"); -} - 
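// NOTE: the signedness-aware comparison dispatch being deleted here (SGT/UGT,
// SGE/UGE, SLT/ULT, ...) is not lost: per the commit message above, dispatch.cc
// is reimplemented in Python in the new python/triton/language/semantic.py,
// where integer signedness is now tracked by the frontend types.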
-ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpUNE(input, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpNE(input, other); - throw_unreachable("equal"); -} - -//===----------------------------------------------------------------------===// -// Block Creation -//===----------------------------------------------------------------------===// - -ir::value* dispatch::arange(int start, int end, ir::builder *builder) { - return builder->get_range(start, end); -} - -ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) { - ir::value *_0 = ir::constant::get_null_value(dtype); - return builder->create_splat(_0, shape); -} - -//===----------------------------------------------------------------------===// -// Shape Manipulation -//===----------------------------------------------------------------------===// - - -ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) { - unsigned numel = 1; - for(unsigned s: dst_shape) numel *= s; - if(input->get_type()->get_tile_num_elements() != numel) - throw semantic_error("cannot reshape block of different shape"); - return builder->create_reshape(input, dst_shape); -} - -ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) { - return builder->create_cat(lhs, rhs); -} - -ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) { - if (!input->get_type()->is_block_ty()) - return builder->create_splat(input, shape); - auto src_shape = input->get_type()->get_block_shapes(); - if (src_shape.size() != shape.size()) - throw std::runtime_error("Cannot broadcast"); - if(shape == src_shape) - return input; - return builder->create_broadcast(input, shape); -} - -std::tuple dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { - ir::type *lhs_ty = lhs->get_type(); - ir::type *rhs_ty = rhs->get_type(); - - // make_shape_compatible(block, scalar) - if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) - rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); - // make_shape_compatible(scalar, block) - else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) - lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); - // make_shape_compatible(block, block) - else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { - auto lhs_shape = lhs_ty->get_block_shapes(); - auto rhs_shape = rhs_ty->get_block_shapes(); - if (lhs_shape.size() != rhs_shape.size()) - throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); - ir::type::block_shapes_t ret_shape; - for (size_t i = 0; i < lhs_shape.size(); ++i) { - unsigned left = lhs_shape[i]; - unsigned right = rhs_shape[i]; - if (left == 1) - ret_shape.push_back(right); - else if (right == 1) - ret_shape.push_back(left); - else if (left == right) - ret_shape.push_back(left); - else - throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + - ": " + std::to_string(left) + " and " + std::to_string(right)); - } - if (lhs_shape != ret_shape) - lhs = builder->create_broadcast(lhs, ret_shape); - if (rhs_shape != ret_shape) - rhs = builder->create_broadcast(rhs, ret_shape); - } - return std::make_tuple(lhs, rhs); -} - -ir::value 
*dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){ - ir::type *src_ty = input->get_type(); - if (src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - if(src_ty == dst_ty) - return input; - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty()) - return cast(input, dst_ty, builder); - // Bitcast - int src_bits = src_sca_ty->get_primitive_size_in_bits(); - int dst_bits = dst_sca_ty->get_primitive_size_in_bits(); - if( src_bits!= dst_bits) - throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) + - "to data-type of size " + std::to_string(dst_bits)); - return builder->create_cast(ir::BitCast, input, dst_ty); -} - -ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) { - ir::type *src_ty = input->get_type(); - if (src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - if(src_ty == dst_ty) - return input; - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - // - if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) || - (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){ - return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder); - } - // FP Truncation - bool truncate_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); - if (truncate_fp) - return builder->create_fp_trunc(input, dst_ty); - // FP Extension - bool ext_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); - if (ext_fp) - return builder->create_fp_ext(input, dst_ty); - // Int cast - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && - (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() || - src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) { - bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty(); - return builder->create_int_cast(input, dst_ty, sign_extend); - } - // Float -> Int - if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ - if(dst_sca_ty->is_bool_ty()) - return builder->create_fp_to_ui(input, dst_ty); - else - return builder->create_fp_to_si(input, dst_ty); - } - // int -> Float - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ - if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed()) - return builder->create_ui_to_fp(input, dst_ty); - else - return builder->create_si_to_fp(input, dst_ty); - } - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){ - int bitwidth = dst_sca_ty->get_integer_bitwidth(); - if(bitwidth == 64) - return builder->create_cast(ir::PtrToInt, input, dst_ty); - if(bitwidth == 1) - return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder), - builder->get_int64(0), - builder); - } - if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::IntToPtr, input, dst_ty); - // Ptr -> Ptr - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::BitCast, input, dst_ty); - // * -> Bool - if (dst_sca_ty->is_bool_ty()) { - if (src_sca_ty->is_pointer_ty()) - 
input = cast(input, builder->get_int64_ty(), builder); - ir::value *other = builder->get_int64(0); - if (src_ty->is_bool_ty()) - other = builder->create_splat(other, src_ty->get_block_shapes()); - return builder->create_icmpNE(input, other); - } - throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); -} - -//===----------------------------------------------------------------------===// -// Memory Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) { - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()){ - if(mask) - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - if(other) - other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); - } - if(other) - other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); - ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); - // treat bool* as int8* - if(elt_ty == builder->get_int1_ty()){ - elt_ty = builder->get_int8_ty(); - ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); - ptr = dispatch::cast(ptr, ptr_ty, builder); - } - // cache modifier - load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default - if (!cache_modifier.empty()) { - if (cache_modifier == ".ca") - cache = load_inst::CA; - else if (cache_modifier == ".cg") - cache = load_inst::CG; - else - throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); - } - // eviction policy - load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default - if(!eviction_policy.empty()){ - if (eviction_policy == "evict_last") - eviction = load_inst::EVICT_LAST; - else if(eviction_policy == "evict_first") - eviction = load_inst::EVICT_FIRST; - else - throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported"); - } - - - if (!mask && !other) - return builder->create_load(ptr, cache, eviction, is_volatile); - if (!mask) - throw std::runtime_error("`other` cannot be provided without `mask`"); - auto shape = ptr->get_type()->get_block_shapes(); - if(!other){ - other = ir::undef_value::get(elt_ty); - if(ptr->get_type()->is_block_ty()) - other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); - } - return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile); -} - -ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()) - val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); - if(mask) - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); - ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); - // treat bool* as int8* - if(elt_ty == builder->get_int1_ty()){ - elt_ty = builder->get_int8_ty(); - ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); - ptr = dispatch::cast(ptr, ptr_ty, 
builder); - } - // cast to target data-type - val = dispatch::cast(val, elt_ty, builder); - if (!mask) - return builder->create_store(ptr, val); - if(!mask->get_type()->get_scalar_ty()->is_bool_ty()) - throw semantic_error("Mask must have boolean scalar type"); - return builder->create_masked_store(ptr, val, mask); -} - -ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){ - return builder->create_atomic_cas(ptr, cmp, val); -} - -void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){ - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()){ - if(mask){ - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - } - if(val){ - val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); - } - } - val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - if(!mask){ - mask = builder->get_int1(true); - if(ptr->get_type()->is_block_ty()) - mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes()); - } -} - -ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) { - if (sca_ty->is_integer_signed()) { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); - } else { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask); - } - } - // for float - // return atomic_smax(i_ptr, i_val) if val >= 0 - // return atomic_umin(i_ptr, i_val) if val < 0 - ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); - ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); - ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); - ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); - ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder)); - ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder)); - return where(pos, pos_ret, neg_ret, builder); -} - -ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_min for integers - if(sca_ty->is_integer_ty()) { - if (sca_ty->is_integer_signed()) { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); - } else { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask); - } - } - // for float - // return atomic_smin(i_ptr, i_val) if val >= 0 - // return atomic_umax(i_ptr, i_val) if val < 0 - ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); - ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); - ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); - ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); - ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder)); - ir::value* neg_ret = 
builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder)); - return where(pos, pos_ret, neg_ret, builder); -} - -ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add; - return builder->create_atomic_rmw(op, ptr, val, mask); -} - -ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask); -} - -ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask); -} - -ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask); -} - -ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask); -} - -//===----------------------------------------------------------------------===// -// Linear Algebra -//===----------------------------------------------------------------------===// - -ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) { - ir::value *_0 = nullptr; - if (lhs->get_type()->is_int_or_tileint_ty()) - _0 = builder->get_int32(0); - else - _0 = builder->get_float32(0); - unsigned M = lhs->get_type()->get_block_shapes()[0]; - unsigned N = rhs->get_type()->get_block_shapes()[1]; - _0 = builder->create_splat(_0, {M, N}); - bool _allow_tf32 = allow_tf32->get_value() != 0; - return builder->create_dot(lhs, rhs, _0, _allow_tf32); -} - - -//===----------------------------------------------------------------------===// -// Indexing -//===----------------------------------------------------------------------===// - -ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){ - condition = dispatch::cast(condition, builder->get_int1_ty(), builder); - if(condition->get_type()->is_block_ty()){ - x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); - y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); - } - ir::type* x_ty = x->get_type()->get_scalar_ty(); - ir::type* y_ty = y->get_type()->get_scalar_ty(); - ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO); - x = dispatch::cast(x, ty, builder); - y = dispatch::cast(y, ty, builder); - return builder->create_select(condition, x, y); -} - - -//===----------------------------------------------------------------------===// -// Reductions -//===----------------------------------------------------------------------===// - -ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, - ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // input is extended to 32-bits if 
necessary - // this increases numerical accuracy and can be done pretty much for free - // on GPUs - if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32) - input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder); - if (scalar_ty->is_floating_point_ty()) - return builder->create_reduce(input, FLOAT_OP, axis); - else if (scalar_ty->is_integer_ty()) - return builder->create_reduce(input, INT_OP, axis); - throw_unreachable(name); -} - -ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); -} - -ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); -} - -ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); -} - -ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - if (!scalar_ty->is_integer_ty()) - throw semantic_error("xor_sum only supported for integers"); - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR); -} - - -//===----------------------------------------------------------------------===// -// Math -//===----------------------------------------------------------------------===// - -ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) { - binary_op_type_checking(x, y, builder); - return builder->insert(umulhi_inst::create(x, y)); -} - -ir::value *dispatch::exp(ir::value *x, ir::builder *builder) { - return builder->create_exp(x); -} - -ir::value *dispatch::log(ir::value *x, ir::builder *builder) { - return builder->create_log(x); -} - -ir::value *dispatch::cos(ir::value *x, ir::builder *builder) { - return builder->create_cos(x); -} - -ir::value *dispatch::sin(ir::value *x, ir::builder *builder) { - return builder->create_sin(x); -} - -ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { - return builder->create_sqrt(x); -} - -// - -ir::value *dispatch::globaltimer(ir::builder *builder) { - return builder->insert(globaltimer_inst::create(builder->get_context())); -} - -ir::value *dispatch::clock(ir::builder *builder) { - return builder->insert(clock_inst::create(builder->get_context())); - -} - -//===----------------------------------------------------------------------===// -// Control FLow -//===----------------------------------------------------------------------===// -// - -ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ - ir::instruction* i = dynamic_cast(x); - if(!i) - throw_unreachable("multiple_of"); - i->set_metadata(ir::metadata::multiple_of, value); - return i; -} - -ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){ - ir::instruction* i = dynamic_cast(x); - if(!i) - throw_unreachable("max_contiguous"); - i->set_metadata(ir::metadata::max_contiguous, value); - return i; -} - -ir::value *dispatch::debug_barrier(ir::builder *builder) { - return builder->create_barrier(); -} - - -} -} diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 1bcbfa9ff..325976504 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -377,8 +377,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool 
is_signed, unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast : - (arg_bits > dst_bits ? cast_op_t::Trunc : - (is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); + (arg_bits > dst_bits ? cast_op_t::Trunc : + (is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); return create(op, arg, ty, name, next); } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 7df196c8f..d86b60085 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -9,154 +9,10 @@ namespace triton{ namespace ir{ -/* */ -value_constructor::value_constructor(ir::builder& builder): builder_(builder){ - sealed_blocks_.insert(nullptr); -} - -void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ - values_[val_key_t{name, block}] = value; - auto it = metadatas_.find(name); - if(auto *x = dynamic_cast(value)) - if(it != metadatas_.end()){ - x->set_metadata(it->second.first, it->second.second); - } -// value->set_name(name); -} - -void value_constructor::set_value(const std::string& name, ir::value *value){ - return set_value(name, builder_.get_insert_block(), value); -} - -ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ - basic_block::iterator insert = block->get_first_non_phi(); - if(insert != block->end()){ - builder_.set_insert_point(insert); - } - ir::phi_node *res = builder_.create_phi(ty, num_values); - if(insert != block->end()) - builder_.set_insert_point(block); - return res; -} - -ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){ - // find non-self references - std::set non_self_ref; - std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), - [phi](ir::value* op){ return op != phi && op; }); - // non-trivial - if(non_self_ref.size() != 1) - return phi; - // unique value or self-reference - ir::value *same = *non_self_ref.begin(); - assert(same != nullptr); - phi->replace_all_uses_with(same); - phi->erase_from_parent(); - std::vector users = phi->get_users(); - for(ir::user* u: users) - if(auto *uphi = dynamic_cast(u)) - if(uphi != phi) - try_remove_trivial_phis(uphi); - return same; -} - - -ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){ - // already initialized - if(phi->get_num_operands()) - return phi; - ir::basic_block *block = phi->get_parent(); - for(ir::basic_block *pred: block->get_predecessors()){ - ir::value *value = get_value(name, pred); - phi->add_incoming(value, pred); - } - return phi; -} - -ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) { - ir::value *result; - auto preds = block->get_predecessors(); - ir::type *ty = types_.at(name); - if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){ - incomplete_phis_[block][name] = make_phi(ty, 1, block); - result = (ir::value*)incomplete_phis_[block][name]; - } - else if(preds.size() <= 1){ - bool has_pred = preds.size(); - result = get_value(name, has_pred?preds.front():nullptr); - } - else{ - ir::phi_node* phi = make_phi(ty, 1, block); - set_value(name, block, phi); - result = add_phi_operands(name, phi); - if(auto *phi = dynamic_cast(result)) - result = try_remove_trivial_phis(phi); - } - if(auto *phi = dynamic_cast(result)){ - result = try_remove_trivial_phis(phi); - } - set_value(name, block, result); - return result; -} - -ir::value 
*value_constructor::get_value(const std::string& name, ir::basic_block *block) { - ir::basic_block* save_block = builder_.get_insert_block(); - ir::basic_block::iterator save_pt = builder_.get_insert_point(); - val_key_t key(name, block); -// std::cout << values_.size() << std::endl; -// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl; - if(values_.find(key) != values_.end()){ - return values_.at(key); - } - ir::value *result = get_value_recursive(name, block); - builder_.set_insert_point(save_block); - if(save_pt != save_block->end()) - builder_.set_insert_point(save_pt); - return result; -} - -ir::value *value_constructor::get_value(const std::string& name) { - return get_value(name, builder_.get_insert_block()); -} - -void value_constructor::seal_block(ir::basic_block *block){ - for(auto &x: incomplete_phis_[block]){ - add_phi_operands(x.first, x.second); - if(get_value(x.first) == x.second) - set_value(x.first, try_remove_trivial_phis(x.second)); - } - sealed_blocks_.insert(block); - incomplete_phis_[block].clear(); -} - - - -/* Module */ - -module::module(const std::string &name, builder &builder) - : name_(name), builder_(builder) { -} - void module::reset_ret_ty(const std::string& name, type* ty) { get_function(name)->get_fn_type()->reset_ret_ty(ty); } -ir::builder& module::get_builder() { - return builder_; -} - -void module::set_continue_fn(std::function fn) { - continue_fn_ = fn; -} - -std::function module::get_continue_fn() { - return continue_fn_; -} - -const std::string& module::get_name() { - return name_; -} - /* functions */ function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *&fn = (function*&)symbols_[name]; diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 735fad965..43761f482 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -36,16 +36,6 @@ unsigned type::get_primitive_size_in_bits() const { unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } -signedness type::get_integer_signedness() const -{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); } - -bool type::is_integer_signed() const { - if (id_ != IntegerTyID) { - throw std::logic_error("type is " + repr() + ", not integer"); - } - return ((integer_type*)(this))->get_signedness() == signedness::SIGNED; -} - unsigned type::get_tile_bitwidth() const { return ((block_type*)(this))->get_bitwidth(); } @@ -145,10 +135,6 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } -integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; } -integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; } -integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; } -integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index b97044421..a1cf7e54e 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -3,7 +3,6 @@ #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" -#include "triton/ir/dispatch.h" #include "triton/ir/enums.h" #include 
"triton/ir/function.h" #include "triton/ir/module.h" @@ -12,10 +11,12 @@ #include #include #include +#include #include #include "Python.h" #include #include +#include #include #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" @@ -541,87 +542,6 @@ void init_triton_codegen(py::module &&m) { }, py::return_value_policy::take_ownership); } -/*****************************************************************************/ -/* User-facing language features */ -/*****************************************************************************/ - -void init_triton_frontend(py::module &&m) { - using ret = py::return_value_policy; - - // programming model - m.def("program_id", &ir::dispatch::program_id, ret::reference); - m.def("num_programs", &ir::dispatch::num_programs, ret::reference); - // binary - m.def("add", &ir::dispatch::add, ret::reference); - m.def("sub", &ir::dispatch::sub, ret::reference); - m.def("mul", &ir::dispatch::mul, ret::reference); - m.def("truediv", &ir::dispatch::truediv, ret::reference); - m.def("floordiv", &ir::dispatch::floordiv, ret::reference); - m.def("fdiv", &ir::dispatch::fdiv, ret::reference); - m.def("mod", &ir::dispatch::mod, ret::reference); - m.def("and_", &ir::dispatch::and_, ret::reference); - m.def("or_", &ir::dispatch::or_, ret::reference); - m.def("xor_", &ir::dispatch::xor_, ret::reference); - m.def("lshr", &ir::dispatch::lshr, ret::reference); - m.def("shl", &ir::dispatch::shl, ret::reference); - // unary - m.def("plus", &ir::dispatch::plus, ret::reference); - m.def("minus", &ir::dispatch::minus, ret::reference); - m.def("invert", &ir::dispatch::invert, ret::reference); - // comparison - m.def("greater_than", &ir::dispatch::greater_than, ret::reference); - m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference); - m.def("less_than", &ir::dispatch::less_than, ret::reference); - m.def("less_equal", &ir::dispatch::less_equal, ret::reference); - m.def("equal", &ir::dispatch::equal, ret::reference); - m.def("not_equal", &ir::dispatch::not_equal, ret::reference); - // block creation - m.def("arange", &ir::dispatch::arange, ret::reference); - m.def("zeros", &ir::dispatch::zeros, ret::reference); - // type manipuatation - m.def("cat", &ir::dispatch::cat, ret::reference); - m.def("reshape", &ir::dispatch::reshape, ret::reference); - typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); - typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); - m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference); - m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference); - m.def("bitcast", &ir::dispatch::bitcast, ret::reference); - m.def("cast", &ir::dispatch::cast, ret::reference); - // memory - m.def("load", &ir::dispatch::load, ret::reference); - m.def("store", &ir::dispatch::store, ret::reference); - m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); - m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); - m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference); - m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference); - m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference); - m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference); - m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference); - m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference); - // linear algebra - m.def("dot", &ir::dispatch::dot, ret::reference); - // indexing - m.def("where", &ir::dispatch::where, 
ret::reference); - // reduction - m.def("min", &ir::dispatch::min, ret::reference); - m.def("max", &ir::dispatch::max, ret::reference); - m.def("sum", &ir::dispatch::sum, ret::reference); - m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); - // math - m.def("umulhi", &ir::dispatch::umulhi, ret::reference); - m.def("exp", &ir::dispatch::exp, ret::reference); - m.def("log", &ir::dispatch::log, ret::reference); - m.def("cos", &ir::dispatch::cos, ret::reference); - m.def("sin", &ir::dispatch::sin, ret::reference); - m.def("sqrt", &ir::dispatch::sqrt, ret::reference); - // utilities - m.def("clock", &ir::dispatch::clock, ret::reference); - m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference); - // internal (debugging only) - m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); - m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference); - m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); -} /*****************************************************************************/ /* Python bindings for triton::ir */ @@ -631,16 +551,86 @@ void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; + py::enum_(m, "CACHE_MODIFIER") + .value("NONE", ir::load_inst::NONE) + .value("CA", ir::load_inst::CA) + .value("CG", ir::load_inst::CG) + .export_values(); + + py::enum_(m, "EVICTION_POLICY") + .value("NORMAL", ir::load_inst::NORMAL) + .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST) + .value("EVICT_LAST", ir::load_inst::EVICT_LAST) + .export_values(); + + py::enum_(m, "REDUCE_OP") + .value("ADD", ir::reduce_inst::ADD) + .value("FADD", ir::reduce_inst::FADD) + .value("MIN", ir::reduce_inst::MIN) + .value("MAX", ir::reduce_inst::MAX) + .value("FMIN", ir::reduce_inst::FMIN) + .value("FMAX", ir::reduce_inst::FMAX) + .value("XOR", ir::reduce_inst::XOR); + + py::enum_(m, "ATOMIC_OP") + .value("ADD", ir::atomic_rmw_op_t::Add) + .value("FADD", ir::atomic_rmw_op_t::FAdd) + .value("AND", ir::atomic_rmw_op_t::And) + .value("OR", ir::atomic_rmw_op_t::Or) + .value("XOR", ir::atomic_rmw_op_t::Xor) + .value("XCHG", ir::atomic_rmw_op_t::Xchg) + .value("MAX", ir::atomic_rmw_op_t::Max) + .value("MIN", ir::atomic_rmw_op_t::Min) + .value("UMIN", ir::atomic_rmw_op_t::UMin) + .value("UMAX", ir::atomic_rmw_op_t::UMax); + py::class_(m, "context") .def(py::init<>()); - auto value = py::class_(m, "value"); - value.def_property("name", &ir::value::get_name, &ir::value::set_name); - value.def_property_readonly("type", &ir::value::get_type); + py::class_(m, "value") + .def("multiple_of", [](ir::value *self, int val) { + if (auto *instr = dynamic_cast(self)) { + instr->set_metadata(ir::metadata::multiple_of, val); + } else + throw std::runtime_error("multiple_of"); + }) + .def("max_contiguous", [](ir::value *self, int val) { + if (auto *instr = dynamic_cast(self)) { + instr->set_metadata(ir::metadata::max_contiguous, val); + } else + throw std::runtime_error("max_contiguous"); + }) + .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) { + if (auto *instr = dynamic_cast(self)) + instr->set_fdiv_ieee_rounding(val); + else + throw std::runtime_error("set_fdiv_ieee_rounding"); + }) + .def("is_phi", [](ir::value *self) { + if (auto *pn = dynamic_cast(self)) + return true; + return false; + }) + .def("ops", [](ir::value *self) { + if (auto *instr = dynamic_cast(self)) { + return instr->ops(); + } + throw std::runtime_error("cannot use ops()"); + }) + .def("replace_all_uses_with", &ir::value::replace_all_uses_with) + 
.def("erase_from_parent", [](ir::value *self) { + if (auto *instr = dynamic_cast(self)) + return instr->erase_from_parent(); + throw std::runtime_error("cannot use erase_from_parent"); + }) + .def_property("name", &ir::value::get_name, &ir::value::set_name) + .def_property_readonly("type", &ir::value::get_type); py::class_(m, "user"); - py::class_(m, "constant"); + py::class_(m, "constant") + .def("get_null_value", &ir::constant::get_null_value, ret::reference) + .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference); py::class_(m, "undef") .def("get", &ir::undef_value::get, ret::reference); @@ -651,18 +641,17 @@ void init_triton_ir(py::module &&m) { .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); py::class_(m, "constant_float") - .def_property_readonly("value", &ir::constant_fp::get_value); + .def_property_readonly("value", &ir::constant_fp::get_value) + .def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference); - py::class_(m, "instruction"); - py::class_(m, "phi_node"); + py::class_(m, "instruction") + .def("get_parent", [](ir::instruction *self) { + return self->get_parent(); + }, ret::reference); + py::class_(m, "phi_node") + .def("add_incoming", &ir::phi_node::add_incoming); py::class_(m, "type") - .def("is_ptr", &ir::type::is_pointer_ty) - .def("is_int", static_cast(&ir::type::is_integer_ty)) - .def("get_int_width", &ir::type::get_integer_bitwidth) - - .def("is_floating", &ir::type::is_floating_point_ty) - .def("is_block", &ir::type::is_block_ty) .def("make_ptr", &ir::pointer_type::get, ret::reference) .def("make_function", &ir::function_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference) @@ -677,35 +666,39 @@ void init_triton_ir(py::module &&m) { .def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference) - .def("get_uint8", &ir::type::get_uint8_ty, ret::reference) - .def("get_uint16", &ir::type::get_uint16_ty, ret::reference) - .def("get_uint32", &ir::type::get_uint32_ty, ret::reference) - .def("get_uint64", &ir::type::get_uint64_ty, ret::reference) + .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference) + .def("get_block_shapes", &ir::type::get_block_shapes) + + .def("is_ptr", &ir::type::is_pointer_ty) + .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("is_floating", &ir::type::is_floating_point_ty) + .def("is_block", &ir::type::is_block_ty) + .def("is_struct", &ir::type::is_struct_ty) .def("is_void", &ir::type::is_void_ty) + .def("is_bool", &ir::type::is_bool_ty) .def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp16", &ir::type::is_fp16_ty) .def("is_bf16", &ir::type::is_bf16_ty) .def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) - .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) - .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, 
ir::signedness::UNSIGNED); }) - .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) - .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) - .def("is_struct", &ir::type::is_struct_ty) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) + .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty) .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) - .def_property_readonly("context", &ir::type::get_context, ret::reference); + .def_property_readonly("context", &ir::type::get_context, ret::reference) + .def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth) + .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits); py::class_(m, "pointer_type") - .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); + .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) + .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference); py::class_(m, "function_type") .def_property_readonly("ret_ty", &ir::function_type::get_return_ty) @@ -723,21 +716,20 @@ void init_triton_ir(py::module &&m) { .def("get", &ir::struct_type::get, ret::reference) .def_property_readonly("num_types", &ir::struct_type::get_num_types); - py::class_(m, "value_constructor") - .def(py::init()) - .def("seal_block", &ir::value_constructor::seal_block) - .def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value) - .def("set_type", &ir::value_constructor::set_type) - .def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference) - .def("get_values", &ir::value_constructor::get_values, ret::reference) - .def("set_values", &ir::value_constructor::set_values); - py::class_(m, "module") .def(py::init()) .def("has_function", &ir::module::has_function) .def("get_function", &ir::module::get_function, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) .def("reset_ret_ty", &ir::module::reset_ret_ty) + .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { + const auto metadatas = self->get_metadatas(); + auto it = metadatas.find(name); + if (it != metadatas.end()) + if (auto *instr = dynamic_cast(value)) { + instr->set_metadata(it->second.first, it->second.second); + } + }) .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; @@ -768,6 +760,13 @@ void init_triton_ir(py::module &&m) { py::class_(m, "basic_block") .def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr) + .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) + .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* { + ir::basic_block::iterator it = self->get_first_non_phi(); + if (it == self->end()) + return nullptr; + return *it; + }, ret::reference) 
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); py::class_(m, "bb_iterator"); @@ -783,22 +782,168 @@ void init_triton_ir(py::module &&m) { .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) .def("ret", &ir::builder::create_ret, ret::reference) - .def("get_insert_point", &ir::builder::get_insert_point) - .def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point) + // insertion block/point, insert points are represented as (*bb, *instr) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) + .def("get_insert_point", [](ir::builder *self) { + ir::basic_block *bb = self->get_insert_block(); + ir::basic_block::iterator it = self->get_insert_point(); + ir::instruction *instr = it == bb->end() ? nullptr : *it; + return std::make_pair(bb, instr); + }, ret::reference) + .def("set_insert_point", [](ir::builder *self, std::pair pt) { + ir::basic_block *bb = pt.first; + ir::instruction *instr = pt.second; + if (instr) { + if (bb != instr->get_parent()) + throw std::runtime_error("invalid insertion point, instr not in bb"); + self->set_insert_point(instr); + } else { + assert(bb); + self->set_insert_point(bb); + } + }) + // Constants + .def("get_int1", &ir::builder::get_int1, ret::reference) + .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) + .def("get_uint32", &ir::builder::get_int32, ret::reference) + .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) + .def("get_uint64", &ir::builder::get_int64, ret::reference) + .def("get_float16", &ir::builder::get_float16, ret::reference) + .def("get_float32", &ir::builder::get_float32, ret::reference) + .def("get_range", &ir::builder::get_range, ret::reference) + // Types + .def("get_void_ty", &ir::builder::get_void_ty, ret::reference) + .def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference) + .def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference) + .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference) + .def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference) + .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference) + .def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference) + .def("get_half_ty", &ir::builder::get_half_ty, ret::reference) + .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference) + .def("get_float_ty", &ir::builder::get_float_ty, ret::reference) + .def("get_double_ty", &ir::builder::get_double_ty, ret::reference) + // terminator instructions + .def("create_br", &ir::builder::create_br, ret::reference) + .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) + .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) + // Cast instructions + .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) + .def("create_cast", &ir::builder::create_cast, ret::reference) + .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) + .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) + .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) + .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) + .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) + .def("create_fp_ext", 
&ir::builder::create_fp_ext, ret::reference) + .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) + .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) + .def("create_downcast", &ir::builder::create_downcast, ret::reference) + // phi + .def("create_phi", &ir::builder::create_phi, ret::reference) + // Binary instructions + .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) + .def("create_fmul", &ir::builder::create_fmul, ret::reference) + .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) + .def("create_frem", &ir::builder::create_frem, ret::reference) + .def("create_fadd", &ir::builder::create_fadd, ret::reference) + .def("create_fsub", &ir::builder::create_fsub, ret::reference) + .def("create_mul", &ir::builder::create_mul, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) + .def("create_udiv", &ir::builder::create_udiv, ret::reference) + .def("create_srem", &ir::builder::create_srem, ret::reference) + .def("create_urem", &ir::builder::create_urem, ret::reference) + .def("create_add", &ir::builder::create_add, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_sub", &ir::builder::create_sub, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_shl", &ir::builder::create_shl, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_lshr", &ir::builder::create_lshr, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_ashr", &ir::builder::create_ashr, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + // GEP + .def("create_gep", &ir::builder::create_gep, ret::reference) + // Comparison (int) + .def("create_icmp", &ir::builder::create_icmp, ret::reference) + .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference) + .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference) + .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference) + .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference) + .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference) + .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference) + .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference) + .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference) + .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference) + .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference) + // Comparison (float) + .def("create_fcmp", &ir::builder::create_fcmp, ret::reference) + .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference) + .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) + .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) + .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) + .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) + .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) + .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) + .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) + .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) + 
.def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) + .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) + .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) + // Logical + .def("create_and", &ir::builder::create_and, ret::reference) + .def("create_xor", &ir::builder::create_xor, ret::reference) + .def("create_or", &ir::builder::create_or, ret::reference) + // Input/Output + .def("create_load", &ir::builder::create_load, ret::reference) + .def("create_store", &ir::builder::create_store, ret::reference) + .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) + .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) + // Block instruction + .def("create_splat", &ir::builder::create_splat, ret::reference) + .def("create_reshape", &ir::builder::create_reshape, ret::reference) + .def("create_cat", &ir::builder::create_cat, ret::reference) + .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) + // atomic + .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) + .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) + // Utilities + .def("create_clock", &ir::builder::create_clock, ret::reference) + .def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference) + + // Built-in instruction + .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) + .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) + .def("create_exp", &ir::builder::create_exp, ret::reference) + .def("create_cos", &ir::builder::create_cos, ret::reference) + .def("create_sin", &ir::builder::create_sin, ret::reference) + .def("create_log", &ir::builder::create_log, ret::reference) + .def("create_dot", &ir::builder::create_dot, ret::reference) + .def("create_trans", &ir::builder::create_trans, ret::reference) + .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) + .def("create_reduce", &ir::builder::create_reduce, ret::reference) + .def("create_select", &ir::builder::create_select, ret::reference) // struct .def("insert_value", &ir::builder::create_insert_value, ret::reference) .def("extract_value", &ir::builder::create_extract_value, ret::reference) - // constants - .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", &ir::builder::get_int32, ret::reference) - .def("get_int64", &ir::builder::get_int64, ret::reference) - .def("get_uint32", &ir::builder::get_uint32, ret::reference) - .def("get_uint64", &ir::builder::get_uint64, ret::reference) - .def("get_float16", &ir::builder::get_float16, ret::reference) - .def("get_float32", &ir::builder::get_float32, ret::reference) - .def("get_range", &ir::builder::get_range, ret::reference); + // Intrinsics + // These have no place in the IR, and hopefully they can be removed at some point + .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) + .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference) + .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference) + .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference) + .def("create_barrier", &ir::builder::create_barrier, ret::reference) + .def("create_async_wait", &ir::builder::create_async_wait, ret::reference) + .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference); } void init_triton(py::module &m) { @@ -806,5 +951,4 @@ void init_triton(py::module &m) { 
init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_ir(std::move(subm.def_submodule("ir"))); - init_triton_frontend(std::move(subm.def_submodule("frontend"))); } diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 1df3a0b49..f30b203bb 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -37,7 +37,7 @@ matmul_data = { (256, 256, 256): {'float16': 0.027}, (512, 512, 512): {'float16': 0.158}, (1024, 1024, 1024): {'float16': 0.466}, - (2048, 2048, 2048): {'float16': 0.680}, + (2048, 2048, 2048): {'float16': 0.695}, (4096, 4096, 4096): {'float16': 0.831}, (8192, 8192, 8192): {'float16': 0.849}, # tall-skinny diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 08e5b721f..c8bfedab4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,5 +1,4 @@ # flake8: noqa: F821,F841 -import copy import itertools import re from typing import Optional, Union @@ -12,7 +11,7 @@ from numpy.random import RandomState import triton import triton._C.libtriton.triton as _triton import triton.language as tl -from triton.code_gen import TensorWrapper, reinterpret +from triton.code_gen import JITFunction, TensorWrapper, reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] @@ -993,11 +992,17 @@ def test_noop(device='cuda'): @pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') ]) def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + spec_type = None + + def cache_hook(*args, **kwargs): + nonlocal spec_type + spec_type = kwargs["compile"]["arg_types"][0][1] + JITFunction.cache_hook = cache_hook @triton.jit def kernel(VALUE, X): @@ -1006,11 +1011,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non x = torch.tensor([3.14159], device='cuda') pgm = kernel[(1, )](value, x) - # Parse out the type of the 'VALUE' parameter from the Triton IR. 
- triton_ir = pgm.asm['ttir'] - ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir) - ir_value_type = None if ir_value_match is None else ir_value_match.group(2) - assert ir_value_type == value_type + JITFunction.cache_hook = None + assert spec_type == value_type @pytest.mark.parametrize( @@ -1045,13 +1047,13 @@ def stub(X, alpha, grid_0, grid_1, grid_2): tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2]) -def test_dyn_par(cond=True, device='cuda'): - n_pids = 10 - # pids = torch.arange(n_pids, device=device) - # alpha = 2.0 - # x_ref = pids * alpha - x_tri = torch.full((10,), fill_value=-1., device=device) - # cond = torch.tensor([cond], device=device) - stub[(1,)](x_tri, 3.14, n_pids, 1, 1) - print(x_tri) - # triton.testing.assert_almost_equal(x_ref, x_tri) +# def test_dyn_par(cond=True, device='cuda'): +# n_pids = 10 +# # pids = torch.arange(n_pids, device=device) +# # alpha = 2.0 +# # x_ref = pids * alpha +# x_tri = torch.full((10,), fill_value=-1., device=device) +# # cond = torch.tensor([cond], device=device) +# stub[(1,)](x_tri, 3.14, n_pids, 1, 1) +# print(x_tri) +# # triton.testing.assert_almost_equal(x_ref, x_tri) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 8ac01bcc8..d866d6983 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,4 +1,5 @@ import os +import re import shutil import pytest @@ -102,3 +103,30 @@ def test_specialize(mode): for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target + + +@pytest.mark.parametrize("value, value_type", [ + (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), + (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), + (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64') +]) +def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + cache_str = None + + def get_cache_str(*args, **kwargs): + nonlocal cache_str + cache_str = kwargs['key'].split('-') + triton.code_gen.JITFunction.cache_hook = get_cache_str + reset_tmp_dir() + x = torch.tensor([3.14159], device='cuda') + kernel[(1, )](value, x) + triton.code_gen.JITFunction.cache_hook = None + + cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) + spec_type = None if cache_str_match is None else cache_str_match.group(1) + assert spec_type == value_type diff --git a/python/triton/__init__.py b/python/triton/__init__.py index f9982939c..37ba46efc 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,7 +6,8 @@ __version__ = '2.0.0' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret +from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ + JITFunction, Config, Autotuner, reinterpret from . import language from . import code_gen from . 
import testing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index e6102366a..b44fc244e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast import builtins import functools @@ -11,7 +13,7 @@ import tempfile import textwrap import time import warnings -from typing import Dict +from typing import Dict, Set, Tuple, Union import torch from filelock import FileLock @@ -21,26 +23,26 @@ import triton._C.libtriton.triton as _triton from .tools.disasm import extract -def mangle_ty(type): - if type.is_ptr(): - return 'P' + mangle_ty(type.element) - if type.is_int(): - return 'i' + str(type.get_int_width()) - if type.is_fp8(): +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + return 'i' + str(ty.int_bitwidth) + if ty.is_fp8(): return 'fp8' - if type.is_fp16(): + if ty.is_fp16(): return 'fp16' - if type.is_bf16(): + if ty.is_bf16(): return 'bf16' - if type.is_fp32(): + if ty.is_fp32(): return 'fp32' - if type.is_fp64(): + if ty.is_fp64(): return 'fp64' - if type.is_void(): + if ty.is_void(): return 'V' - if type.is_block(): - elt = mangle_ty(type.scalar) - shape = '_'.join(map(str, type.shape)) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) return f'{elt}S{shape}S' assert False, "Unsupport type" @@ -56,8 +58,38 @@ def mangle_fn(name, arg_tys, constants): return ret -class CodeGenerator(ast.NodeVisitor): +def is_triton_tensor(value): + return isinstance(value, triton.language.tensor) + + +class ValueConstructor: + def __init__(self, module, builder, gscope) -> None: + self.gscope = gscope + self.lscope = dict() + self.builder = builder + self.module = module + # [name, bb] => triton.language.tensor + self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {} + # bb => {name => phi} + self.incomplete_phis = {} + self.sealed_blocks: Set[_triton.ir.basic_block] = set() + # + self.builtins = { + 'range': range, + 'min': triton.language.minimum, + 'float': float, + 'int': int, + 'print': print, + 'isinstance': isinstance, + 'getattr': getattr, + } + def get_value(self, name): + ''' This function: + 1. make sure `name` is defined + 2. if `name` is triton.language.tensor, get stored tensor by calling + `self._get_tensor()` + ''' # search node.id in local scope ret = None if name in self.lscope: @@ -70,21 +102,123 @@ class CodeGenerator(ast.NodeVisitor): ret = self.builtins[name] else: raise ValueError(f'{name} is not defined') - if isinstance(ret, triton.language.block): - handle = self.value_constructor.get_value(name) - return triton.language.block(handle) + if is_triton_tensor(ret): + return self._get_tensor(name, self.builder.get_insert_block()) return ret - def set_value(self, name, value): - if isinstance(value, _triton.ir.value): - value = triton.language.block(value) - if isinstance(value, triton.language.block): - self.value_constructor.set_value(name, value.handle) - self.value_constructor.set_type(name, value.handle.type) + def set_value(self, name: str, + value: Union[triton.language.tensor, triton.language.constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FuncDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. 
store tensor in self.lvalue + ''' self.lscope[name] = value + if isinstance(value, triton.language.tensor): + self._set_value(name, self.builder.get_insert_block(), value) - def is_triton_object(self, value): - return isinstance(value, triton.language.block) + # + # SSA-construction + # + def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: + # local value numbering + if (name, bb) in self.lvalues: + return self.lvalues[(name, bb)] + # global value numbering + saved_insert_point = self.builder.get_insert_point() + result = self._get_tensor_recursive(name, bb) + self.builder.set_insert_point(saved_insert_point) + return result + + def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: + preds = bb.get_predecessors() + type = self.lscope[name].type + # some preds haven't been filled, create a phi as a proxy of the value + if bb not in self.sealed_blocks: + result = self._make_phi(type, len(preds), bb) + if bb in self.incomplete_phis: + self.incomplete_phis[bb][name] = result + else: + self.incomplete_phis[bb] = {name: result} + elif len(preds) == 1: + # one predecessor: no phi needed, try get value from pred + result = self._get_tensor(name, preds[0]) + elif len(preds) == 0: + result = self._get_tensor(name, None) + else: # multiple preds + phi = self._make_phi(type, len(preds), bb) + self._set_value(name, bb, phi) + result = self._add_phi_operands(name, phi) + self._set_value(name, bb, result) + return result + + # returns a new phi tensor, which encausulate an ir.phi_node + def _make_phi(self, + type: triton.language.dtype, + num_values: int, + bb: _triton.ir.basic_block) -> triton.language.tensor: + instr = bb.get_first_non_phi() + self.builder.set_insert_point((bb, instr)) + ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values) + if instr: + self.builder.set_insert_block(bb) + return triton.language.tensor(ir_phi, type) + + # complete a phi node. (TODO: rename this as _complete_phis?) + # Note: since we try to remove tryival phi, the return tensor might not be a phi + def _add_phi_operands(self, name: str, + phi: triton.language.tensor) -> triton.language.tensor: + bb = phi.handle.get_parent() + for pred in bb.get_predecessors(): + v = self._get_tensor(name, pred) + phi.handle.add_incoming(v.handle, pred) + phi = self._try_remove_trivial_phi(phi) + return phi + + def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None: + self.lvalues[(name, bb)] = value + # TODO: why we need this? 
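Illustrative sketch, not part of the patch: the ValueConstructor above performs on-the-fly SSA construction in the style of Braun et al. (per-block value numbering, placeholder phis in unsealed blocks, operand completion in `_seal_block`). The stand-alone toy below mirrors the same read/write/seal protocol; every name in it (Block, read_variable, write_variable, seal_block) is hypothetical and independent of Triton.

class Block:
    def __init__(self):
        self.preds = []        # predecessor blocks
        self.sealed = False    # True once every predecessor is known

defs = {}           # (variable name, block) -> value
incomplete = {}     # block -> {variable name: placeholder phi}

def write_variable(name, block, value):
    defs[(name, block)] = value

def read_variable(name, block):
    if (name, block) in defs:                 # local value numbering
        return defs[(name, block)]
    if not block.sealed:                      # placeholder phi, completed in seal_block()
        phi = ['phi', block, []]
        incomplete.setdefault(block, {})[name] = phi
    elif len(block.preds) == 1:               # single predecessor: forward, no phi needed
        phi = read_variable(name, block.preds[0])
    else:                                     # join point: real phi over all predecessors
        phi = ['phi', block, []]
        write_variable(name, block, phi)      # record first, so cycles terminate
        phi[2].extend(read_variable(name, p) for p in block.preds)
    write_variable(name, block, phi)
    return phi

def seal_block(block):
    for name, phi in incomplete.pop(block, {}).items():
        phi[2].extend(read_variable(name, p) for p in block.preds)
    block.sealed = True

# usage on a trivial single-block CFG
entry = Block(); seal_block(entry)
write_variable('x', entry, 1.0)
assert read_variable('x', entry) == 1.0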
+ self.module.set_instr_metadata(name, value.handle) + + def _seal_block(self, bb: _triton.ir.basic_block): + # complete all incomplete phis + if bb in self.incomplete_phis: + for name, phi in self.incomplete_phis[bb].items(): + result = self._add_phi_operands(name, phi) + # it's possible that this phi is trivial + if self._get_tensor(name, bb).handle == phi.handle: + self._set_value(name, bb, result) + del self.incomplete_phis[bb] + self.sealed_blocks.add(bb) + + def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor: + unique_handles = {op for op in phi.handle.ops() if op != phi.handle} + if len(unique_handles) != 1: # non-trivial phi + return phi + v = unique_handles.pop() + phi.handle.replace_all_uses_with(v) + phi.handle.erase_from_parent() + # TODO: remove trivial phis recursively + return triton.language.tensor(v, phi.type) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, prototypes=None, module=None, is_kernel=False): + self.prototypes = dict() if prototypes is None else prototypes + self.builder = _triton.ir.builder(context) + self.module = _triton.ir.module('', self.builder) if module is None else module + self.prototype = prototype + self.attributes = attributes + self.constants = constants + self.last_node = None + self.is_kernel = is_kernel + + self.value_constructor = ValueConstructor(self.module, self.builder, gscope) + + # + # AST visitor + # def visit_compound_statement(self, stmts): for stmt in stmts: @@ -93,27 +227,6 @@ class CodeGenerator(ast.NodeVisitor): break return stmts and isinstance(stmt, ast.Return) - def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False): - self.builder = _triton.ir.builder(context) - self.value_constructor = _triton.ir.value_constructor(self.builder) - self.module = _triton.ir.module('', self.builder) if module is None else module - self.prototype = prototype - self.gscope = gscope - self.lscope = dict() - self.attributes = attributes - self.constants = constants - self.last_node = None - self.is_kernel = is_kernel - self.builtins = { - 'range': range, - 'min': triton.language.minimum, - 'float': float, - 'int': int, - 'print': print, - 'isinstance': isinstance, - 'getattr': getattr, - } - def visit_Module(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -127,16 +240,10 @@ class CodeGenerator(ast.NodeVisitor): def visit_Return(self, node): ret = self.visit(node.value) if ret is None: - return self.builder.ret_void() - if isinstance(ret, _triton.ir.value): - ret = self.builder.ret(ret) - return ret - if isinstance(ret, triton.language.block): - ret = ret.handle - if isinstance(ret, triton.language.constexpr): - ret = triton.language.core._to_ir(ret, self.builder) - # TODO: should return tl.block - return self.builder.ret(ret) + return triton.language.tensor(self.builder.ret_void(), triton.language.void) + ret = triton.language.core._to_tensor(ret, self.builder) + ret = triton.language.tensor(self.builder.ret(ret.handle), ret.type) + return ret def visit_FunctionDef(self, node): arg_names, kwarg_names = self.visit(node.args) @@ -152,8 +259,9 @@ class CodeGenerator(ast.NodeVisitor): init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) # initialize function - fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants) - fn = self.module.get_or_insert_function(fn_name, self.prototype) + fn_name = mangle_fn(node.name, 
self.prototype.param_types, self.constants) + self.prototypes[fn_name] = self.prototype + fn = self.module.get_or_insert_function(fn_name, self.prototype.to_ir(self.builder)) fn.set_is_kernel(self.is_kernel) arg_values = [] idx = 0 @@ -171,23 +279,24 @@ class CodeGenerator(ast.NodeVisitor): attr = _triton.ir.attribute(attr, self.attributes[i]) fn.add_attr(idx + 1, attr) fn.args[idx].name = arg_name - arg_values.append(fn.args[idx]) + arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) idx += 1 insert_pt = self.builder.get_insert_block() entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) self.builder.set_insert_block(entry) - self.value_constructor.seal_block(entry) + self.value_constructor._seal_block(entry) for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) + self.value_constructor.set_value(arg_name, arg_value) # visit function body has_ret = self.visit_compound_statement(node.body) # finalize if not has_ret: self.builder.ret_void() else: - self.module.reset_ret_ty(fn_name, self.last_ret.type) - # self.module.reset_ret_type(node.name) + # a bit hacky: we only know the return type at the last moment so we update type info here + self.module.reset_ret_ty(fn_name, self.last_ret.type.to_ir(self.builder)) + self.prototype.ret_type = self.last_ret.type self.builder.set_insert_block(insert_pt) def visit_arguments(self, node): @@ -208,13 +317,13 @@ class CodeGenerator(ast.NodeVisitor): value = self.visit(node.value) # constexpr if annotation == triton.language.constexpr: - if target in self.lscope: + if target in self.value_constructor.lscope: raise ValueError(f'{target} is already defined.' f' constexpr cannot be reassigned.') if not isinstance(value, triton.language.constexpr): value = triton.language.constexpr(value) - self.lscope[target] = value - return self.lscope[target] + self.value_constructor.lscope[target] = value + return self.value_constructor.lscope[target] # default: call visit_Assign return self.visit_Assign(node) @@ -229,19 +338,21 @@ class CodeGenerator(ast.NodeVisitor): names = [names] if not isinstance(values, tuple): values = [values] - if isinstance(values[0], _triton.ir.value): - struct = values[0] - ty = struct.type - if ty.is_struct(): - values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)] + if isinstance(values[0], triton.language.tensor) \ + and isinstance(values[0].type, triton.language.tuple_type): + struct = values[0].handle + tys = values[0].type.element_types + values = [self.builder.extract_value(struct, i) for i in range(len(tys))] + values = [triton.language.tensor(v, ty) for v, ty in zip(values, tys)] assert len(values) == len(names) for name, value in zip(names, values): + # TODO: can we store constexpr here to support constant folding? 
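As a concrete illustration of what the constexpr handling above enables (a sketch with hypothetical kernel and argument names, not part of this patch): a parameter annotated `tl.constexpr` reaches `visit_FunctionDef` as a compile-time constant, so it can size `tl.arange` and is folded into the specialization key instead of being passed at runtime.

import triton
import triton.language as tl

@triton.jit
def scale_kernel(X, Y, n_elements, alpha, BLOCK: tl.constexpr):
    # BLOCK is a triton.language.constexpr: valid anywhere a compile-time int is required
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(X + offsets, mask=mask)
    tl.store(Y + offsets, x * alpha, mask=mask)

# launch: BLOCK is a meta-parameter baked into the compiled binary
# scale_kernel[(triton.cdiv(n_elements, 1024),)](x, y, n_elements, 2.0, BLOCK=1024)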
# by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value - if not isinstance(value, triton.language.block): - value = triton.language.core._to_ir(value, self.builder) - self.set_value(name, value) + if not isinstance(value, triton.language.tensor): + value = triton.language.core._to_tensor(value, self.builder) + self.value_constructor.set_value(name, value) def visit_AugAssign(self, node): name = node.target.id @@ -249,12 +360,12 @@ class CodeGenerator(ast.NodeVisitor): rhs = ast.BinOp(lhs, node.op, node.value) assign = ast.Assign(targets=[node.target], value=rhs) self.visit(assign) - return self.get_value(name) + return self.value_constructor.get_value(name) def visit_Name(self, node): if type(node.ctx) == ast.Store: return node.id - return self.get_value(node.id) + return self.value_constructor.get_value(node.id) def visit_Store(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -266,23 +377,22 @@ class CodeGenerator(ast.NodeVisitor): args = [self.visit(x) for x in node.elts] mode = type(args[0]) # tuple of values -- create a struct - if len(args) > 1 and mode == triton.language.block\ + if len(args) > 1 and mode == triton.language.tensor\ and all([type(arg) == mode for arg in args]): - args = [arg.handle for arg in args] - tys = [arg.type for arg in args] - struct_ty = _triton.ir.struct_type.get(tys, True) - ret = _triton.ir.undef.get(struct_ty) + tuple_ty = triton.language.tuple_type([arg.type for arg in args]) + ret = _triton.ir.undef.get(tuple_ty.to_ir(self.builder)) for i, arg in enumerate(args): - ret = self.builder.insert_value(ret, arg, i) + ret = self.builder.insert_value(ret, arg.handle, i) + ret = triton.language.tensor(ret, tuple_ty) return ret return tuple(args) def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) - if isinstance(lhs, triton.language.core.constexpr): + if isinstance(lhs, triton.language.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.core.constexpr): + if isinstance(rhs, triton.language.constexpr): rhs = rhs.value fn = { ast.Add: '__add__', @@ -298,9 +408,9 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - if self.is_triton_object(lhs): + if is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_object(rhs): + elif is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -308,15 +418,15 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if isinstance(cond, triton.language.block): + if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self.value_constructor.seal_block(then_bb) + self.value_constructor._seal_block(then_bb) if else_bb: - self.value_constructor.seal_block(else_bb) + self.value_constructor._seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -331,7 +441,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
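For reference, a hedged sketch (hypothetical names, not part of this patch) of how the two `visit_If` paths surface at the kernel level: a tensor condition lowers to `cond_br` over then/else/endif blocks, while a constexpr condition is evaluated while visiting the AST and the untaken branch is never generated.

import triton
import triton.language as tl

@triton.jit
def bias_copy(X, B, Y, n_elements, USE_BIAS: tl.constexpr, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(X + offsets, mask=mask)
    if USE_BIAS:  # constexpr condition: decided at compile time, no branch in the IR
        x = x + tl.load(B + offsets, mask=mask)
    tl.store(Y + offsets, x, mask=mask)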
if not is_terminator: self.builder.br(endif_bb) - self.value_constructor.seal_block(endif_bb) + self.value_constructor._seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -356,9 +466,9 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.core.constexpr): + if isinstance(lhs, triton.language.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.core.constexpr): + if isinstance(rhs, triton.language.constexpr): rhs = rhs.value if type(node.ops[0]) == ast.Is: return triton.language.constexpr(lhs is rhs) @@ -372,9 +482,9 @@ class CodeGenerator(ast.NodeVisitor): ast.Gt: '__gt__', ast.GtE: '__ge__', }[type(node.ops[0])] - if self.is_triton_object(lhs): + if is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_object(rhs): + elif is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -385,21 +495,21 @@ class CodeGenerator(ast.NodeVisitor): if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.core.constexpr): + if isinstance(op, triton.language.constexpr): op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Invert: '__invert__', }[type(node.op)] - if self.is_triton_object(op): + if is_triton_tensor(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() def visit_While(self, node): current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) def continue_fn(): cond = self.visit(node.test) @@ -410,9 +520,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self.value_constructor.seal_block(stop_bb) - self.value_constructor.seal_block(loop_bb) - self.value_constructor.seal_block(next_bb) + self.value_constructor._seal_block(stop_bb) + self.value_constructor._seal_block(loop_bb) + self.value_constructor._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -422,7 +532,7 @@ class CodeGenerator(ast.NodeVisitor): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) - if self.is_triton_object(lhs): + if is_triton_tensor(lhs): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] @@ -431,7 +541,7 @@ class CodeGenerator(ast.NodeVisitor): def visit_For(self, node): iterator = self.visit(node.iter.func) - if iterator != self.builtins['range']: + if iterator != self.value_constructor.builtins['range']: raise RuntimeError('Only `range` iterator currently supported') # static for loops: all iterator arguments are constexpr iter_args = [self.visit(arg) for arg in node.iter.args] @@ -442,7 +552,7 @@ class CodeGenerator(ast.NodeVisitor): range = iterator(*iter_args) if len(range) <= 10: for i in iterator(*iter_args): - self.lscope[node.target.id] = triton.language.constexpr(i) + 
self.value_constructor.lscope[node.target.id] = triton.language.constexpr(i) self.visit_compound_statement(node.body) for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) @@ -465,8 +575,8 @@ class CodeGenerator(ast.NodeVisitor): step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) def continue_fn(): self.visit(step_node) @@ -481,9 +591,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self.value_constructor.seal_block(stop_bb) - self.value_constructor.seal_block(loop_bb) - self.value_constructor.seal_block(next_bb) + self.value_constructor._seal_block(stop_bb) + self.value_constructor._seal_block(loop_bb) + self.value_constructor._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -514,7 +624,7 @@ class CodeGenerator(ast.NodeVisitor): from inspect import getcallargs args = getcallargs(fn.fn, *args, **kws) args = [args[name] for name in fn.arg_names] - args = [arg if isinstance(arg, triton.language.block) + args = [arg if isinstance(arg, triton.language.tensor) else triton.language.constexpr(arg) for arg in args] # generate function def attributes = dict() @@ -523,25 +633,24 @@ class CodeGenerator(ast.NodeVisitor): # generate call args = [None if i in constexprs else arg for i, arg in enumerate(args)] arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in arg_vals] + arg_types = [arg.type for arg in args if arg is not None] fn_name = mangle_fn(fn.__name__, arg_types, constants) # generate function def if necessary if not self.module.has_function(fn_name): - ret_type = _triton.ir.type.get_void(self.builder.context) - prototype = _triton.ir.type.make_function(ret_type, arg_types) + ret_type = triton.language.void + prototype = triton.language.function_type(ret_type, arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module) + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, prototypes=self.prototypes, module=self.module) generator.visit(fn.parse()) symbol = self.module.get_function(fn_name) ret = self.builder.call(symbol, arg_vals) - if not ret.type.is_void() and not ret.type.is_struct(): - ret = triton.language.block(ret) + if not ret.type.is_void(): + ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type) return ret # built-in function - if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language.core: + if sys.modules[fn.__module__] is triton.language.core: ret = fn(*args, _builder=self.builder, **kws) - if fn in self.builtins.values(): + if fn in self.value_constructor.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args] ret = fn(*args, **kws) @@ -698,7 +807,7 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] - if isinstance(obj, 
triton.language.core.constexpr): + if isinstance(obj, triton.language.constexpr): obj = obj.value if isinstance(obj, int): if -2**31 <= obj < 2**31: @@ -730,34 +839,34 @@ class Kernel: return 'scalar', name @staticmethod - def _to_triton_ir(context, obj): + def _to_triton_ir(obj): which, name = obj type_map = { - 'I': _triton.ir.type.get_int32, - 'L': _triton.ir.type.get_int64, - 'f': _triton.ir.type.get_fp32, - 'B': _triton.ir.type.get_int1, - 'f8': _triton.ir.type.get_fp8, - 'f16': _triton.ir.type.get_fp16, - 'bf16': _triton.ir.type.get_bf16, - 'f32': _triton.ir.type.get_fp32, - 'f64': _triton.ir.type.get_fp64, - 'i1': _triton.ir.type.get_int1, - 'i8': _triton.ir.type.get_int8, - 'i16': _triton.ir.type.get_int16, - 'i32': _triton.ir.type.get_int32, - 'i64': _triton.ir.type.get_int64, - 'u8': _triton.ir.type.get_uint8, - 'u16': _triton.ir.type.get_uint16, - 'u32': _triton.ir.type.get_uint32, - 'u64': _triton.ir.type.get_uint64, + 'I': triton.language.int32, + 'L': triton.language.int64, + 'f': triton.language.float32, + 'B': triton.language.int1, + 'f8': triton.language.float8, + 'f16': triton.language.float16, + 'bf16': triton.language.bfloat16, + 'f32': triton.language.float32, + 'f64': triton.language.float64, + 'i1': triton.language.int1, + 'i8': triton.language.int8, + 'i16': triton.language.int16, + 'i32': triton.language.int32, + 'i64': triton.language.int64, + 'u8': triton.language.uint8, + 'u16': triton.language.uint16, + 'u32': triton.language.uint32, + 'u64': triton.language.uint64, } # convert torch.Tensor to Triton IR pointers if which == 'ptr': - elt_ty = type_map[name](context) - return _triton.ir.type.make_ptr(elt_ty, 1) + elt_ty = type_map[name] + return triton.language.pointer_type(elt_ty, 1) # default path returns triton.ir.type directly - return type_map[name](context) + return type_map[name] @staticmethod def pow2_divisor(N): @@ -1121,9 +1230,9 @@ class JITFunction: # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] - ret_type = _triton.ir.type.get_void(context) - prototype = _triton.ir.type.make_function(ret_type, arg_types) + arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] + ret_type = triton.language.void + prototype = triton.language.function_type(ret_type, arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index cad4edfe4..6aa1b68cd 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,63 +1,36 @@ +from __future__ import annotations + +from enum import Enum from functools import wraps +from typing import List import triton -from triton._C.libtriton.triton import frontend, ir +from . 
import semantic +from triton._C.libtriton.triton import ir -# convert block/dtype to ir values -def _to_ir(x, builder): +def _to_tensor(x, builder): if isinstance(x, bool): - return builder.get_int1(x) + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values elif isinstance(x, int): if -2**31 <= x < 2**31: - return builder.get_int32(x) + return tensor(builder.get_int32(x), int32) elif 2**31 <= x < 2**32: - return builder.get_uint32(x) + return tensor(builder.get_uint32(x), uint32) elif -2**63 <= x < 2**63: - return builder.get_int64(x) + return tensor(builder.get_int64(x), int64) elif 2**63 <= x < 2**64: - return builder.get_uint64(x) + return tensor(builder.get_uint64(x), uint64) else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - return builder.get_float32(x) + return tensor(builder.get_float32(x), float32) elif isinstance(x, constexpr): - return _to_ir(x.value, builder) - elif isinstance(x, block): - return x.handle - elif isinstance(x, dtype): - return x.handle(builder) - return x - - -def _patch(fn): - def _from_ir(x): - if isinstance(x, ir.value): - if x.type.is_void(): - return None - return block(x) + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): return x - - def wrapper(*args, **kwargs): - builder = args[-1] - assert isinstance(builder, ir.builder) - args = [_to_ir(x, builder) for x in args] - # for i, arg in enumerate(args): - # if arg is None: - # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") - kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} - ret = fn(*args, **kwargs) - if isinstance(ret, tuple): - return map(_from_ir, ret) - return _from_ir(ret) - - return wrapper - - -for name in dir(frontend): - fn = getattr(frontend, name) - if callable(fn): - setattr(frontend, name, _patch(fn)) + assert False, f'cannot convert {x} to tensor' def builtin(fn): @@ -72,20 +45,147 @@ def builtin(fn): class dtype: - def __init__(self, init): - self.init = init + SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return self.name == 'fp8' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + 
def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def is_void(self): + return self.name == 'void' + + def is_block(self): + return False + + def is_ptr(self): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name,)) @property - def name(self) -> str: - # The init functions are named something like 'get_int8'. Strip the prefix. - nom = self.init.__name__ - prefix = 'get_' - assert nom.startswith(prefix) - return nom[len(prefix):] + def scalar(self): + return self - def handle(self, builder): - ctx = builder.context - return self.init(ctx) + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8' or self.name == 'uint8': + return builder.get_int8_ty() + elif self.name == 'int16' or self.name == 'uint16': + return builder.get_int16_ty() + elif self.name == 'int32' or self.name == 'uint32': + return builder.get_int32_ty() + elif self.name == 'int64' or self.name == 'uint64': + return builder.get_int64_ty() + elif self.name == 'fp8': + return builder.get_fp8_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to covert {self} to ir type') def __str__(self): return self.name @@ -99,36 +199,124 @@ class dtype: return f'triton.language.{self.name}' -class pointer_dtype: - def __init__(self, element_ty): +class pointer_type(dtype): + def __init__(self, element_ty: dtype, address_space: int = 1): if not isinstance(element_ty, dtype): raise TypeError('element_ty is a {type(element_ty).__name__}.') self.element_ty = element_ty + self.address_space = address_space - def handle(self, builder): - return ir.type.make_ptr(self.element_ty.handle(builder), 1) + self.name = self.__str__() + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return ir.type.make_ptr(self.element_ty.to_ir(builder), 1) def __str__(self): return f'pointer<{self.element_ty}>' + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class block_type(dtype): + def __init__(self, element_ty: dtype, shape: List[int]): + self.element_ty = element_ty 
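A quick sanity sketch of the new dtype objects (assumptions: this patch is installed and the unsigned dtypes are re-exported at the `triton.language` level; `pointer_type` is imported from `triton.language.core` in case it is not). Dtypes are now plain Python values compared by name, so predicates and equality work outside the JIT.

import triton.language as tl
from triton.language.core import pointer_type

assert tl.float16.is_fp16() and tl.float16.is_floating()
assert tl.int32.is_int() and tl.int32.is_int_signed()
assert tl.uint32.is_int() and not tl.uint32.is_int_signed()
assert tl.float32 == tl.float32 and tl.float32 != tl.int32

p = pointer_type(tl.float16)
assert p.is_ptr() and p.element_ty == tl.float16
assert str(p) == 'pointer<fp16>'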
+ # FIXME: + # block_type's shape is a list of int + # while tensor's shape is a list of constexpr + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + + self.name = self.__str__() + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return ir.type.make_block(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return f'<{self.shape}, {self.element_ty}>' + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None: + self.ret_type = ret_type + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_type}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types) + + +class tuple_type(dtype): + def __init__(self, element_types: List[dtype]) -> None: + self.element_types = element_types + + def __str__(self): + return f'<{self.element_types}>' + + def to_ir(self, builder: ir.builder): + ir_element_types = [ty.to_ir(builder) for ty in self.element_types] + return ir.struct_type.get(ir_element_types, True) + # scalar types -int1 = dtype(ir.type.get_int1) -int8 = dtype(ir.type.get_int8) -int16 = dtype(ir.type.get_int16) -int32 = dtype(ir.type.get_int32) -int64 = dtype(ir.type.get_int64) -uint8 = dtype(ir.type.get_uint8) -uint16 = dtype(ir.type.get_uint16) -uint32 = dtype(ir.type.get_uint32) -uint64 = dtype(ir.type.get_uint64) -float8 = dtype(ir.type.get_fp8) -float16 = dtype(ir.type.get_fp16) -bfloat16 = dtype(ir.type.get_bf16) -float32 = dtype(ir.type.get_fp32) -float64 = dtype(ir.type.get_fp64) +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8 = dtype('fp8') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') # pointer types -pi32_t = pointer_dtype(int32) +pi32_t = pointer_type(int32) # ----------------------- # constexpr @@ -149,7 +337,6 @@ class constexpr: def __repr__(self) -> str: return f"constexpr[{self.value}]" - # def __add__(self, other): return self.value + other.value @@ -219,31 +406,33 @@ class constexpr: return self.value(*args, **kwds) -class block: +class tensor: + # infer dtype from ir type @staticmethod - def _init_dtype(ir_type): + def _to_dtype(ir_type): + # block type + if ir_type.is_block(): + scalar_ty = tensor._to_dtype(ir_type.scalar) + return block_type(scalar_ty, ir_type.get_block_shapes()) + # pointer type + if ir_type.is_ptr(): + element_ty = tensor._to_dtype(ir_type.element) + return pointer_type(element_ty) # primitive type + if ir_type.is_void(): return void if ir_type.is_int1(): return int1 if ir_type.is_int8(): return int8 if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 - if ir_type.is_uint8(): return uint8 - if ir_type.is_uint16(): 
return uint16 - if ir_type.is_uint32(): return uint32 - if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16 if ir_type.is_fp32(): return float32 if ir_type.is_fp64(): return float64 - # pointer type - if ir_type.is_ptr(): - element_ty = block._init_dtype(ir_type.element) - return pointer_dtype(element_ty) - raise ValueError(f"Unsupported type {ir_type}") + raise ValueError(f"Unsupported type {ir_type.repr()}") - def __init__(self, handle): + def __init__(self, handle, type: dtype): # IR handle self.handle = handle # Block shape @@ -254,9 +443,9 @@ class block: for s in self.shape: self.numel *= s self.numel = constexpr(self.numel) - # Data-type wrapper - self.dtype = block._init_dtype(self.handle.type.scalar) - # Shape is a constexpr + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: @@ -265,116 +454,139 @@ class block: @builtin def __add__(self, other, _builder=None): - return frontend.add(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - return frontend.sub(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) def __rsub__(self, other, _builder=None): - return frontend.sub(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) @builtin def __mul__(self, other, _builder=None): - return frontend.mul(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @builtin def __truediv__(self, other, _builder=None): - return frontend.truediv(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) def __rtruediv__(self, other, _builder=None): - return frontend.truediv(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) @builtin def __floordiv__(self, other, _builder=None): - return frontend.floordiv(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) @builtin def __mod__(self, other, _builder=None): - return frontend.mod(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) # unary operators @builtin def __neg__(self, _builder=None): - return frontend.minus(self, _builder) + return semantic.minus(self, _builder) @builtin def __invert__(self, _builder=None): - return frontend.invert(self, _builder) + return semantic.invert(self, _builder) # bitwise operators @builtin def __and__(self, other, _builder=None): - return frontend.and_(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) @builtin def __or__(self, other, _builder=None): - return frontend.or_(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) @builtin def __xor__(self, other, _builder=None): - return frontend.xor_(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) @builtin def 
__lshift__(self, other, _builder=None): - return frontend.shl(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) @builtin def __rshift__(self, other, _builder=None): - return frontend.lshr(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.lshr(self, other, _builder) # comparison operators # > @builtin def __gt__(self, other, _builder=None): - return frontend.greater_than(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) @builtin def __rgt__(self, other, _builder=None): - return frontend.greater_than(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) # >= @builtin def __ge__(self, other, _builder=None): - return frontend.greater_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + @builtin def __rge__(self, other, _builder=None): - return frontend.greater_equal(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) # < @builtin def __lt__(self, other, _builder=None): - return frontend.less_than(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) @builtin def __rlt__(self, other, _builder=None): - return frontend.less_than(other, self, _builder) + return semantic.less_than(other, self, _builder) # <= @builtin def __le__(self, other, _builder=None): - return frontend.less_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) @builtin def __rle__(self, other, _builder=None): - return frontend.less_equal(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) # == @builtin def __eq__(self, other, _builder=None): - return frontend.equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) @builtin def __ne__(self, other, _builder=None): - return frontend.not_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) @builtin def __getitem__(self, slices, _builder=None): @@ -389,20 +601,25 @@ class block: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) curr += 1 - ret = frontend.reshape(self, dst_shape, _builder) + ret = semantic.reshape(self, dst_shape, _builder) return ret @builtin def to(self, dtype, bitcast=False, _builder=None): - dtype = dtype.handle(_builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value if bitcast: - return frontend.bitcast(self, dtype, _builder) - return frontend.cast(self, dtype, _builder) + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder) # ----------------------- # SPMD Programming Model # ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v @builtin @@ -414,13 +631,14 @@ def program_id(axis, _builder=None): :type axis: int """ # if axis == -1: - # pid0 = frontend.program_id(0, _builder) - # pid1 = frontend.program_id(1, _builder) - # pid2 = frontend.program_id(2, _builder) - # npg0 = frontend.num_programs(0, _builder) - # npg1 = frontend.num_programs(0, _builder) + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, 
_builder) + # npg1 = num_programs(0, _builder) # return pid0 + pid1*npg0 + pid2*npg0*npg1 - return frontend.program_id(axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) @builtin @@ -431,7 +649,8 @@ def num_programs(axis, _builder=None): :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - return frontend.num_programs(axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) # ----------------------- @@ -449,13 +668,15 @@ def arange(start, end, _builder=None): :param stop: End of the interval. Must be a power of two >= start. :type stop: int """ - return frontend.arange(start, end, _builder) + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) @builtin def zeros(shape, dtype, _builder=None): """ - Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. :param shape: Shape of the new array, e.g., (8, 16) or (8, ) :type shape: tuple of ints @@ -468,7 +689,8 @@ def zeros(shape, dtype, _builder=None): if not isinstance(d.value, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") shape = [x.value for x in shape] - return frontend.zeros(shape, dtype, _builder) + dtype = _constexpr_to_value(dtype) + return semantic.zeros(shape, dtype, _builder) # ----------------------- @@ -481,25 +703,25 @@ def broadcast(input, other, _builder=None): """ Tries to broadcast the two given blocks to a common compatible shape. - :param input: The first input block. + :param input: The first input tensor. :type input: Block - :param other: The second input block. + :param other: The second input tensor. :type other: Block """ - return frontend.broadcast(input, other, _builder) + return semantic.broadcast_impl_value(input, other, _builder) @builtin def broadcast_to(input, shape, _builder=None): """ - Tries to broadcast the given block to a new :code:`shape`. + Tries to broadcast the given tensor to a new :code:`shape`. - :param input: The input block. + :param input: The input tensor. :type input: Block :param shape: The desired shape. :type shape: Tuple[int] """ - return frontend.broadcast_to(input, shape, _builder) + return semantic.broadcast_impl_shape(input, shape, _builder) @builtin @@ -507,27 +729,27 @@ def cat(input, other, _builder=None): """ Concatenate the given blocks - :param input: The first input block. + :param input: The first input tensor. :type input: - :param other: The second input block. + :param other: The second input tensor. :type other: """ - return frontend.cat(input, other, _builder) + return semantic.cat(input, other, _builder) @builtin def reshape(input, shape, _builder=None): """ - Tries to reshape the given block to a new shape. + Tries to reshape the given tensor to a new shape. - :param input: The input block. + :param input: The input tensor. :type input: :param shape: The desired shape. :type shape: Tuple[int] """ shape = [x.value for x in shape] - return frontend.reshape(input, shape, _builder) + return semantic.reshape(input, shape, _builder) # ----------------------- @@ -542,12 +764,13 @@ def dot(input, other, allow_tf32=True, _builder=None): The two blocks must be two dimensionals and have compatible inner dimensions. - :param input: The first block to be multiplied. 
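To show the tensor-level builtins above working together, here is a minimal sketch (hypothetical names, not part of this patch) assuming A, B, C point to row-major BLOCK x BLOCK tiles with C in fp32: `tl.arange` builds indices, slicing with `None` reshapes them to 2D, `tl.zeros` seeds an fp32 accumulator, and `tl.dot` performs the block-level product.

import triton
import triton.language as tl

@triton.jit
def block_dot(A, B, C, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(A + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(B + offs[:, None] * BLOCK + offs[None, :])
    acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    acc += tl.dot(a, b)
    tl.store(C + offs[:, None] * BLOCK + offs[None, :], acc)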
- :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} - :param other: The second block to be multiplied. - :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ - return frontend.dot(input, other, allow_tf32, _builder) + allow_tf32 = _constexpr_to_value(allow_tf32) + return semantic.dot(input, other, allow_tf32, _builder) # ----------------------- @@ -558,7 +781,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ - Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. + Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. @@ -573,24 +796,36 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) + # mask, other can be constexpr + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) @builtin def store(pointer, value, mask=None, _builder=None): """ - Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. + Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. :param pointer: The memory locations where the elements of :code:`value` are stored. :type pointer: Block of dtype=triton.PointerDType - :param value: The block of elements to be stored. + :param value: The tensor of elements to be stored. :type value: Block :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`. 
:type mask: Block of triton.int1, optional """ - return frontend.store(pointer, value, mask, _builder) + # value can be constexpr + value = _to_tensor(value, _builder) + if mask is not None: + mask = _to_tensor(mask, _builder) + return semantic.store(pointer, value, mask, _builder) # ----------------------- @@ -621,49 +856,58 @@ def _add_atomic_docstr(name): @builtin @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): - return frontend.atomic_cas(pointer, cmp, val, _builder) + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(cmp, _builder) + return semantic.atomic_cas(pointer, cmp, val, _builder) @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, _builder=None): - return frontend.atomic_xchg(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_xchg(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, _builder=None): - return frontend.atomic_add(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_add(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("max") def atomic_max(pointer, val, mask=None, _builder=None): - return frontend.atomic_max(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_max(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("min") def atomic_min(pointer, val, mask=None, _builder=None): - return frontend.atomic_min(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_min(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical and") def atomic_and(pointer, val, mask=None, _builder=None): - return frontend.atomic_and(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_and(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical or") def atomic_or(pointer, val, mask=None, _builder=None): - return frontend.atomic_or(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_or(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical xor") def atomic_xor(pointer, val, mask=None, _builder=None): - return frontend.atomic_xor(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_xor(pointer, val, mask, _builder) # ----------------------- @@ -674,7 +918,7 @@ def atomic_xor(pointer, val, mask=None, _builder=None): @builtin def where(condition, x, y, _builder=None): """ - Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. @@ -688,7 +932,10 @@ def where(condition, x, y, _builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. 
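The atomics above all follow the same wrap-then-dispatch pattern (run the value through `_to_tensor`, then call the matching `semantic.atomic_*`). A minimal sketch of masked `tl.atomic_add`, with hypothetical buffer names and not part of this patch:

import triton
import triton.language as tl

@triton.jit
def histogram_kernel(Indices, Hist, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    bins = tl.load(Indices + offsets, mask=mask, other=0)
    # the plain Python `1` is converted by _to_tensor before reaching semantic.atomic_add
    tl.atomic_add(Hist + bins, 1, mask=mask)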
""" - return frontend.where(condition, x, y, _builder) + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) # ----------------------- @@ -697,12 +944,15 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): - return frontend.umulhi(x, y, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.umulhi(x, y, _builder) @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): - return frontend.fdiv(x, y, ieee_rounding, _builder) + ieee_rounding = _constexpr_to_value(ieee_rounding) + return semantic.fdiv(x, y, ieee_rounding, _builder) def _add_math_1arg_docstr(name): @@ -723,31 +973,31 @@ def _add_math_1arg_docstr(name): @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): - return frontend.exp(x, _builder) + return semantic.exp(x, _builder) @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): - return frontend.log(x, _builder) + return semantic.log(x, _builder) @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): - return frontend.cos(x, _builder) + return semantic.cos(x, _builder) @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): - return frontend.sin(x, _builder) + return semantic.sin(x, _builder) @builtin @_add_math_1arg_docstr("square root") def sqrt(x, _builder=None): - return frontend.sqrt(x, _builder) + return semantic.sqrt(x, _builder) # ----------------------- @@ -758,7 +1008,7 @@ def _add_reduction_docstr(name): def _decorator(func): docstr = """ - Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis` + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` :param input: the input values :param axis: the dimension along which the reduction should be done @@ -772,25 +1022,29 @@ def _add_reduction_docstr(name): @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): - return frontend.max(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.max(input, axis, _builder) @builtin @_add_reduction_docstr("minimum") def min(input, axis, _builder=None): - return frontend.min(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.min(input, axis, _builder) @builtin @_add_reduction_docstr("sum") def sum(input, axis, _builder=None): - return frontend.sum(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.sum(input, axis, _builder) @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): - return frontend.xor_sum(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.xor_sum(input, axis, _builder) # ----------------------- # Utilities @@ -799,12 +1053,12 @@ def xor_sum(input, axis, _builder=None): @builtin def globaltimer(_builder=None): - return frontend.globaltimer(_builder) + return semantic.globaltimer(_builder) @builtin def clock(_builder=None): - return frontend.clock(_builder) + return semantic.clock(_builder) # ----------------------- # Internal for debugging @@ -813,7 +1067,7 @@ def clock(_builder=None): @builtin def debug_barrier(_builder=None): - return frontend.debug_barrier(_builder) + return semantic.debug_barrier(_builder) @builtin @@ -821,7 +1075,8 @@ def multiple_of(input, value, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. 
""" - return frontend.multiple_of(input, value, _builder) + value = _constexpr_to_value(value) + return semantic.multiple_of(input, value) @builtin @@ -829,7 +1084,8 @@ def max_contiguous(input, value, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - return frontend.max_contiguous(input, value, _builder) + value = _constexpr_to_value(value) + return semantic.max_contiguous(input, value) # ----------------------- @@ -859,9 +1115,9 @@ def minimum(x, y): """ Computes the element-wise minimum of :code:`x` and :code:`y`. - :param input: the first input block + :param input: the first input tensor :type input: Block - :param other: the second input block + :param other: the second input tensor :type other: Block """ return triton.language.where(x < y, x, y) @@ -872,9 +1128,9 @@ def maximum(x, y): """ Computes the element-wise maximum of :code:`x` and :code:`y`. - :param input: the first input block + :param input: the first input tensor :type input: Block - :param other: the second input block + :param other: the second input tensor :type other: Block """ return triton.language.where(x > y, x, y) @@ -900,7 +1156,7 @@ def ravel(x): """ Returns a contiguous flattened view of :code:`x` - :param x: the input block + :param x: the input tensor :type x: Block """ return triton.language.reshape(x, [x.numel]) @@ -947,21 +1203,21 @@ def zeros_like(input): # ----------------------- -class LaunchProxy: +# class LaunchProxy: - def __init__(self, fn, args, constants, grid, num_warps) -> None: - self.args = args - self.grid = grid - self.constants = constants - self.num_warps = num_warps - self.fn = fn +# def __init__(self, fn, args, constants, grid, num_warps) -> None: +# self.args = args +# self.grid = grid +# self.constants = constants +# self.num_warps = num_warps +# self.fn = fn -@builtin -def launch(fn, args, grid, num_warps=None, _builder=None): - constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)} - args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)] - grid = [_to_ir(x, builder=_builder) for x in grid] - if num_warps is None: - num_warps = _to_ir(4, builder=_builder) - return LaunchProxy(fn, args, constants, grid, num_warps) +# @builtin +# def launch(fn, args, grid, num_warps=None, _builder=None): +# constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)} +# args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)] +# grid = [_to_ir(x, builder=_builder) for x in grid] +# if num_warps is None: +# num_warps = _to_ir(4, builder=_builder) +# return LaunchProxy(fn, args, constants, grid, num_warps) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py new file mode 100644 index 000000000..10d20fbb3 --- /dev/null +++ b/python/triton/language/semantic.py @@ -0,0 +1,1052 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Tuple + +from . 
import core as tl +from triton._C.libtriton.triton import ir + + +# Create custom exception that prints message "hello" +class IncompatibleTypeErrorimpl(Exception): + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorimpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + assert False + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + if not a_ty.is_int() or not b_ty.is_int(): + assert False + # 4 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorimpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorimpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorimpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, + arithmetic_check=True, div_or_mod=False + ) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + assert False + + +def sub(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), + input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + assert False + + +def mul(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + assert False + + +def truediv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and 
other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + assert False + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + assert False + + +def fdiv(input: tl.tensor, + other: tl.tensor, + ieee_rounding: bool, + builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise ValueError("both operands of fdiv must have floating poscalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + ret.set_fdiv_ieee_rounding(ieee_rounding) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + assert False + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def lshr(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, + builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, + builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> 
tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + assert False + + +def greater_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + assert False + + +def equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + assert False + + +def not_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + assert False + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + 
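+    # Materializes the half-open range [start, end) as a 1-D block of
+    # `end - start` int32 values; start and end must be integers that are
+    # known at compile time.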
shape = [end - start] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.get_range(start, end), ret_ty) + + +def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + _0 = ir.constant.get_null_value(dtype.to_ir(builder)) + ret_ty = tl.block_type(dtype, shape) + return tl.tensor(builder.create_splat(_0, shape), ret_ty) + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def reshape(input: tl.tensor, + dst_shape: List[int], + builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("cannot reshape block of different shape") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + assert lhs.type.shape[1:] == rhs.type.shape[1:] + ret_shape = [lhs.type.shape[0] + rhs.type.shape[0]] + ret_ty = tl.block_type(lhs.type.scalar, ret_shape) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_ty) + + +def broadcast_impl_shape(input: tl.tensor, + shape: List[int], + builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + if len(lhs_shape) != len(rhs_shape): + raise ValueError("Cannot make_shape_compatible: blocks must have the same rank") + ret_shape = [] + for i in range(len(lhs_shape)): + left = lhs_shape[i] + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif right == 1: + ret_shape.append(left) + elif left == right: + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + +####### +# cast +####### + + +def 
bitcast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), + dst_ty) + + +def cast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # bf16 <=> (not fp32) + if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ + (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # FP Truncation + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # FP Extension + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Int cast + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + return tl.tensor(builder.create_int_cast(input.handle, + dst_ty.to_ir(builder), sign_extend), + dst_ty) + + # Float to Int + if src_sca_ty.is_floating() and dst_sca_ty.is_int(): + # TODO: is this correct? + if dst_sca_ty.is_bool(): + return tl.tensor(builder.create_fp_to_ui(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_fp_to_si(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # int => float + if src_sca_ty.is_int() and dst_sca_ty.is_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # ptr => int + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), + dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), + tl.tensor(builder.get_int64(0), tl.int64), + builder) + + if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + # Ptr . Ptr + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + # * . 
Bool + if dst_sca_ty.is_bool(): + if src_sca_ty.is_ptr(): + input = cast(input, tl.int64, builder) + other = builder.get_int64(0) + if src_ty.is_bool(): + other = builder.create_splat(other, src_ty.get_block_shapes()) + return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty) + assert False, f'cannot cast {input} to {dst_ty}' + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def load(ptr: tl.tensor, + mask: Optional[tl.tensor], + other: Optional[tl.tensor], + cache_modifier: str, + eviction_policy: str, + is_volatile: bool, + builder: ir.builder) -> tl.tensor: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + if other: + other = cast(other, ptr.type.scalar.element_ty, builder) + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + # treat bool* as tl.int8* + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # cache modifier + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + + # eviction policy + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + dst_ty = elt_ty + + if not mask and not other: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), + dst_ty) + if not mask: + raise ValueError("`other` cannot be provided without `mask`") + + if not other: + other_ir = ir.undef.get(elt_ty.to_ir(builder)) + if ptr.type.is_block(): + other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes()) + other = tl.tensor(other_ir, dst_ty) + + return tl.tensor(builder.create_masked_load(ptr.handle, + mask.handle, + other.handle, + cache, eviction, is_volatile), + dst_ty) + + +def store(ptr: tl.tensor, + val: tl.tensor, + mask: Optional[tl.tensor], + builder: ir.builder) -> tl.tensor: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + # treat bool* as tl.int8* + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # cast to target data-type + val = cast(val, elt_ty, builder) + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return 
tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, + cmp: tl.tensor, + val: tl.tensor, + builder: ir.builder) -> tl.tensor: + # TODO: type checking + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_min(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + i_ptr.handle, + i_val.handle, + and_(mask, pos, 
builder).handle), + i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + i_ptr.handle, + i_val.handle, + and_(mask, neg, builder).handle), + i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_add(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_and(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_or(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xor(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xchg(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def dot(lhs: tl.tensor, + rhs: tl.tensor, + allow_tf32: bool, + builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + if lhs.type.scalar.is_int(): + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + else: + _0 = builder.get_float32(0) + ret_scalar_ty = tl.float32 + M = lhs.type.shape[0] + N = rhs.type.shape[1] + _0 = builder.create_splat(_0, [M, N]) + ret_ty = tl.block_type(ret_scalar_ty, [M, N]) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + +def where(condition: tl.tensor, + x: tl.tensor, + y: tl.tensor, + builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) + y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) + + # TODO: we need to check x's and y's shape? 
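+    # From here, x and y are promoted to a common computation type using the
+    # same promotion rules as binary ops (e.g. where(cond, <int32 block>,
+    # <float32 block>) yields a float32 block) before lowering to a select.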
+ x_ty = x.type.scalar + y_ty = y.type.scalar + ty = computation_type_impl(x_ty, y_ty, div_or_mod=False) + x = cast(x, ty, builder) + y = cast(y, ty, builder) + if x.type.is_block(): + ret_ty = tl.block_type(ty, x.type.shape) + else: + ret_ty = ty + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reductions +# ===----------------------------------------------------------------------=== + +def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, + FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: + scalar_ty = input.type.scalar + # input is extended to 32-bits if necessary + # this increases numerical accuracy and can be done pretty much for free + # on GPUs + if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: + input = cast(input, tl.int32, builder) + + # get result type + shape = input.type.shape + ret_shape = [] + for i, s in enumerate(shape): + if i != axis: + ret_shape.append(s) + if len(ret_shape) == 0: + res_ty = scalar_ty + else: + res_ty = tl.block_type(scalar_ty, ret_shape) + + if scalar_ty.is_floating(): + return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) + elif scalar_ty.is_int(): + return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty) + assert False + + +def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) + + +def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) + + +def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) + + +def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) + + +# ----------------------- +# Utilities +# ----------------------- + +def clock(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_clock(), tl.int64) + + +def globaltimer(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_globaltimer, tl.int64) + + +# ===----------------------------------------------------------------------=== +# Math +# ===----------------------------------------------------------------------=== + +def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + x, y = binary_op_type_checking_impl(x, y, builder) + return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type) + + +def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_exp(x.handle), x.type) + + +def log(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_log(x.handle), x.type) + + +def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_cos(x.handle), x.type) + + +def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sin(x.handle), x.type) + + +def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sqrt(x.handle), x.type) + + +## + +def multiple_of(x: tl.tensor, value: int) -> tl.tensor: + x.handle.multiple_of(value) + return x + + +def 
max_contiguous(x: tl.tensor, value: int) -> tl.tensor: + x.handle.max_contiguous(value) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(''), tl.void) From 6424771f552b4c391706df74cf707eb29d639c40 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 7 Apr 2022 09:42:03 -0700 Subject: [PATCH 087/215] [CI] Documentation fixup --- .github/workflows/documentation.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index e921709ba..7dfb0a489 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -18,6 +18,11 @@ jobs: with: ref: 'gh-pages' + - name: Clear docs + run: | + rm -r /tmp/triton-docs + continue-on-error: true + - name: Checkout branch uses: actions/checkout@v1 @@ -31,7 +36,6 @@ jobs: run: | git branch # update docs - rm -r /tmp/triton-docs; mkdir /tmp/triton-docs; mv docs/_build/html/* /tmp/triton-docs/ git checkout gh-pages From 14b0fd4cfba4128fd0d55199bb1649e81a6da754 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 7 Apr 2022 12:11:32 -0700 Subject: [PATCH 088/215] [FRONTEND] Added possibility for users to customize current stream query (#492) --- python/triton/code_gen.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b44fc244e..51e3577ae 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -22,6 +22,8 @@ import triton import triton._C.libtriton.triton as _triton from .tools.disasm import extract +current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream + def mangle_ty(ty): if ty.is_ptr(): @@ -787,6 +789,7 @@ class OutOfResources(Exception): class Kernel: + @staticmethod def _type_name(obj): type_names = { @@ -915,28 +918,24 @@ class Kernel: raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") # handle annotations for pos, _type in self.fn.annotations.items(): + assert _type == triton.language.constexpr, "only constexpr annotations are supported for now" wargs[pos] = _type(wargs[pos]) # check that tensors are on GPU. for arg in wargs: if hasattr(arg, 'data_ptr'): assert arg.is_cuda, "All tensors must be on GPU!" - # query device index and cuda stream + # set device (i.e., make sure torch has the context initialized) device = torch.cuda.current_device() torch.cuda.set_device(device) + # query compute capability cc = torch.cuda.get_device_capability(device) cc = str(cc[0]) + '-' + str(cc[1]) - # # query stream - # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` - # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 - # # building a C wrapper to re-use the unpack function would add a build-time torch dependency - # # and require different wheels for different torch versions -- undesirable! 
- # bits = torch._C._cuda_getCurrentStream(device) - # mask = 1 << 47 - # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask - stream = torch.cuda.current_stream(device).cuda_stream - # make key for cache - return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, - self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) + cache_key = self.fn.cache_key + cc + # query current stream + stream = current_stream(device) + return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, + device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, + grid) class Launcher: From 76bfac9f15ebc3e3cadcacefce71918b12928407 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 12 Apr 2022 00:02:54 -0700 Subject: [PATCH 089/215] [FRONTEND] Improved constexpr handling (#493) --- lib/ir/builder.cc | 1 - python/test/unit/language/test_core.py | 54 +++++++++++++------- python/triton/code_gen.py | 43 ++++++++++------ python/triton/language/core.py | 68 +++----------------------- 4 files changed, 70 insertions(+), 96 deletions(-) diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index d79e5d9d1..4060f23bb 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c8bfedab4..4570a5c61 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1032,28 +1032,44 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda' kernel[(1, )](value, x) else: kernel[(1, )](value, x) -# ------------------------- -# test dynamic parallelism -# ------------------------- -@triton.jit -def mult(x, alpha): - tl.store(x + tl.program_id(0), alpha) +# ---------------- +# test constexpr +# ---------------- + +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + x = numpy_random((1,), dtype_str="float32") + y = numpy_random((1,), dtype_str="float32") + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x) + y_tri = to_triton(y) + z_tri = to_triton(np.empty((1,), dtype=z.dtype)) + kernel[(1,)](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri)) -@triton.jit -def stub(X, alpha, grid_0, grid_1, grid_2): - tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2]) +def test_constexpr_shape(): + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) -# def test_dyn_par(cond=True, device='cuda'): -# n_pids = 10 -# # pids = torch.arange(n_pids, device=device) -# # alpha = 2.0 -# # x_ref = pids * alpha -# x_tri = torch.full((10,), fill_value=-1., device=device) -# # cond = torch.tensor([cond], device=device) -# stub[(1,)](x_tri, 3.14, n_pids, 1, 1) -# print(x_tri) -# # triton.testing.assert_almost_equal(x_ref, x_tri) + x_tri = to_triton(np.empty((256, ), dtype=np.int32)) + kernel[(1,)](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) diff --git 
a/python/triton/code_gen.py b/python/triton/code_gen.py index 51e3577ae..311fb85f0 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -390,12 +390,14 @@ class CodeGenerator(ast.NodeVisitor): return tuple(args) def visit_BinOp(self, node): + # visit operand lhs = self.visit(node.left) rhs = self.visit(node.right) - if isinstance(lhs, triton.language.constexpr): - lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): - rhs = rhs.value + is_lhs_constexpr = isinstance(lhs, triton.language.constexpr) + is_rhs_constexpr = isinstance(rhs, triton.language.constexpr) + lhs = lhs.value if is_lhs_constexpr else lhs + rhs = rhs.value if is_rhs_constexpr else rhs + # get function name fn = { ast.Add: '__add__', ast.Sub: '__sub__', @@ -410,6 +412,10 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] + # return a new constexpr if both arg are constexprs + if is_lhs_constexpr and is_rhs_constexpr: + return triton.language.constexpr(getattr(lhs, fn)(rhs)) + # call operator if is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) elif is_triton_tensor(rhs): @@ -468,14 +474,16 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.constexpr): - lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): - rhs = rhs.value + is_lhs_constexpr = isinstance(lhs, triton.language.constexpr) + is_rhs_constexpr = isinstance(rhs, triton.language.constexpr) + lhs = lhs.value if is_lhs_constexpr else lhs + rhs = rhs.value if is_rhs_constexpr else rhs + # handle `is`` and `is not`` if type(node.ops[0]) == ast.Is: return triton.language.constexpr(lhs is rhs) if type(node.ops[0]) == ast.IsNot: return triton.language.constexpr(lhs is not rhs) + # function name fn = { ast.Eq: '__eq__', ast.NotEq: '__ne__', @@ -484,29 +492,32 @@ class CodeGenerator(ast.NodeVisitor): ast.Gt: '__gt__', ast.GtE: '__ge__', }[type(node.ops[0])] + # return a new constexpr if both arg are constexprs + if is_lhs_constexpr and is_rhs_constexpr: + return triton.language.constexpr(getattr(lhs, fn)(rhs)) + # call operator if is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) elif is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: - return getattr(lhs, fn)(rhs) + assert False def visit_UnaryOp(self, node): op = self.visit(node.operand) if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.constexpr): - op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Invert: '__invert__', }[type(node.op)] - if is_triton_tensor(op): - return getattr(op, fn)(_builder=self.builder) - return getattr(op, fn)() + if isinstance(op, triton.language.constexpr): + return triton.language.constexpr(getattr(op.value, fn)()) + assert is_triton_tensor(op) + return getattr(op, fn)(_builder=self.builder) def visit_While(self, node): current_bb = self.builder.get_insert_block() @@ -656,6 +667,10 @@ class CodeGenerator(ast.NodeVisitor): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args] ret = fn(*args, **kws) + if isinstance(ret, (bool, int, float)): + ret = triton.language.core.constexpr(ret) + else: + ret = triton.language.core._to_tensor(ret, self.builder) # special case: dynamic 
parallelism # in this case the core primitive returns a proxy # if isinstance(ret, triton.language.core.LaunchProxy): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6aa1b68cd..75d75e8ea 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -337,68 +337,6 @@ class constexpr: def __repr__(self) -> str: return f"constexpr[{self.value}]" - def __add__(self, other): - return self.value + other.value - - def __radd__(self, other): - return other.value + self.value - - def __sub__(self, other): - return self.value - other.value - - def __rsub__(self, other): - return other.value - self.value - - def __mul__(self, other): - return self.value * other.value - - def __rmul__(self, other): - return other.value * self.value - - def __truediv__(self, other): - return self.value / other.value - - def __rtruediv__(self, other): - return other.value / self.value - - def __floordiv__(self, other): - return self.value // other.value - - def __rfloordiv__(self, other): - return other.value // self.value - - # - - def __gt__(self, other): - return self.value > other.value - - def __rgt__(self, other): - return other.value > self.value - - def __ge__(self, other): - return self.value >= other.value - - def __rge__(self, other): - return other.value >= self.value - - def __lt__(self, other): - return self.value < other.value - - def __rlt__(self, other): - return other.value < self.value - - def __le__(self, other): - return self.value <= other.value - - def __rle__(self, other): - return other.value <= self.value - - def __eq__(self, other): - return self.value == other.value - - def __ne__(self, other): - return self.value != other.value - def __bool__(self): return bool(self.value) @@ -496,6 +434,11 @@ class tensor: other = _to_tensor(other, _builder) return semantic.mod(self, other, _builder) + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + # unary operators @builtin def __neg__(self, _builder=None): @@ -564,6 +507,7 @@ class tensor: @builtin def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) return semantic.less_than(other, self, _builder) # <= From 25f66895083982aa7c9a2ccf6600ebc1d9199d2b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 13 Apr 2022 11:45:55 -0700 Subject: [PATCH 090/215] [FRONTEND] rename current stream monkey patch (#495) --- python/triton/code_gen.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 311fb85f0..7553bee55 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -22,7 +22,12 @@ import triton import triton._C.libtriton.triton as _triton from .tools.disasm import extract -current_stream = lambda device: torch.cuda.current_stream(device).cuda_stream + +def current_cuda_stream(device_idx=0): + # Torch's torch.cuda.current_stream() is slow. We provide this + # function to give the user an opportunity to monkey-patch their + # own faster current stream lookup. 
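+    # For example (an illustrative sketch only, not part of this patch), a user
+    # who already holds the raw stream handle as an integer could skip the
+    # torch lookup entirely with something like:
+    #     import triton
+    #     triton.code_gen.current_cuda_stream = lambda device_idx=0: my_stream_handle
+    # where `my_stream_handle` is a hypothetical cached value of
+    # torch.cuda.current_stream(device_idx).cuda_stream.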
+ return torch.cuda.current_stream().cuda_stream def mangle_ty(ty): @@ -947,7 +952,7 @@ class Kernel: cc = str(cc[0]) + '-' + str(cc[1]) cache_key = self.fn.cache_key + cc # query current stream - stream = current_stream(device) + stream = current_cuda_stream(device) return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) From dc4d40faec8fdb8b9d46ab2e36c303a8d3e34f29 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 14 Apr 2022 10:26:48 -0700 Subject: [PATCH 091/215] [FRONTEND] now mangle constexpr float containing "e-" --- python/triton/code_gen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 7553bee55..5fd1c1be6 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -61,6 +61,7 @@ def mangle_fn(name, arg_tys, constants): mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) mangled_constants = mangled_constants.replace('.', '_d_') mangled_constants = mangled_constants.replace("'", '_sq_') + mangled_constants = mangled_constants.replace("e-", '_em_') ret = f'{name}__{mangled_arg_names}__{mangled_constants}' return ret From 5c7122004c25266f5dfd65c5613813107732278c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 14 Apr 2022 17:33:44 -0700 Subject: [PATCH 092/215] [TUTORIALS] Tutorial shouldn't expose `clock`. Just removed it. --- python/tutorials/01-vector-add.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 51de7ac6c..d684106f1 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -24,11 +24,9 @@ def add_kernel( y_ptr, # *Pointer* to second input vector output_ptr, # *Pointer* to output vector n_elements, # Size of the vector - time_start_ptr, time_end_ptr, BLOCK_SIZE: tl.constexpr, # Number of elements each program should process # NOTE: `constexpr` so it can be used as a shape value ): - tl.atomic_min(time_start_ptr, tl.clock()) # There are multiple 'program's processing different data. We identify which program # we are here pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 @@ -47,7 +45,6 @@ def add_kernel( output = x + y # Write x + y back to DRAM tl.store(output_ptr + offsets, output, mask=mask) - tl.atomic_max(time_end_ptr, tl.clock()) # %% @@ -56,8 +53,6 @@ def add_kernel( def add(x: torch.Tensor, y: torch.Tensor): - time_start = torch.zeros(1, dtype=torch.int64, device='cuda') - time_end = torch.zeros(1, dtype=torch.int64, device='cuda') # We need to preallocate the output output = torch.empty_like(x) assert x.is_cuda and y.is_cuda and output.is_cuda @@ -70,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor): # - each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel # - don't forget to pass meta-parameters as keywords arguments - add_kernel[grid](x, y, output, n_elements, time_start, time_end, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. 
return output From 073be1d2ee38353a6e71c535b3460bc160fd3822 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 14 Apr 2022 19:30:02 -0700 Subject: [PATCH 093/215] [FRONTEND] check that tensors have power-of-two number of elements (#499) --- python/triton/language/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 75d75e8ea..046e60e0e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -380,6 +380,9 @@ class tensor: self.numel = 1 for s in self.shape: self.numel *= s + is_pow2 = (self.numel and (not(self.numel & (self.numel - 1)))) + if not is_pow2: + raise ValueError("Triton tensors must have a power-of-two number of elements") self.numel = constexpr(self.numel) self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type From 7d6c504e8db37bd070ad31dd0a542edb1a5ad37c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 21 Apr 2022 22:40:10 -0700 Subject: [PATCH 094/215] [TESTING] Added testing utilities for fixing clock and using cuda-memcheck (#500) --- python/triton/testing.py | 76 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/python/triton/testing.py b/python/triton/testing.py index fbca719ff..bfcd6ef6b 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,6 +1,8 @@ +import functools import os import subprocess import sys +from contextlib import contextmanager import torch @@ -358,6 +360,80 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + def decorator(test_fn): + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checkng failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + return wrapper + return decorator + + +def nvsmi_attr(attrs): + attrs = ",".join(attrs) + cmd = [ + "nvidia-smi", + "-i", + "0", + "--query-gpu=" + attrs, + "--format=csv,noheader,nounits", + ] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(",") + ret = [int(x) for x in ret] + return ret + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output( + [ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ] + ) + subprocess.check_output( + [ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ] + ) + cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0] + cur_mem_clock = 
nvsmi_attr(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None): if not backend: From 0cc3b1129beb9f99235a7bc07d27f2565d53ffe7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 21 Apr 2022 23:56:01 -0700 Subject: [PATCH 095/215] [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) --- include/triton/codegen/selection/generator.h | 4 +++ lib/codegen/selection/generator.cc | 26 +++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index e3191efb1..a4f1d33af 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -4,6 +4,7 @@ #define _TRITON_SELECTION_GENERATOR_H_ #include "triton/ir/visitor.h" +#include "triton/ir/instructions.h" #include "triton/codegen/analysis/layout.h" #include @@ -261,6 +262,9 @@ private: /// Record prefetch instrs that needs to be moved std::map> prefetch_latch_to_bb_; + + // Eviction policies + std::map policies_; }; } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 0e6ae4539..c60350060 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -91,6 +91,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define i8_ty builder_->getInt8Ty() #define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() +#define i64_ty builder_->getInt64Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) 
PointerType::get(__VA_ARGS__) // constants @@ -778,6 +779,7 @@ void generator::visit_load_inst(ir::load_inst* x){ int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); + bool has_evict_policy = x->get_eviction_policy() != ir::load_inst::NORMAL; // ----- // create inline asm string // ----- @@ -789,8 +791,9 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; - if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; + if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; + if (has_evict_policy) asm_oss << ".L2::cache_hint"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size @@ -801,7 +804,9 @@ void generator::visit_load_inst(ir::load_inst* x){ } asm_oss << "}"; asm_oss << ", [ $" << n_words + 1; // load - asm_oss << " + " << in_off << "];"; // constant offset + asm_oss << " + " << in_off << "]"; // constant offset + if (has_evict_policy) asm_oss << ", $" << n_words + 2; + asm_oss << ";"; bool has_other = other && (other != UndefValue::get(other->getType())); std::vector others; // handle `other` values for indices where the mask @@ -822,7 +827,7 @@ void generator::visit_load_inst(ir::load_inst* x){ if(ConstantInt* cst = dyn_cast(v)) asm_oss << "0x" << std::hex << cst->getSExtValue(); else{ - asm_oss << "$" << n_words + 2 + ii; + asm_oss << "$" << n_words + has_evict_policy + 2 + ii; others.push_back(v); } asm_oss.flags(flags); @@ -837,6 +842,8 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); + if (has_evict_policy) + arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); // --- // create inline ASM constraints @@ -851,6 +858,8 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); } + if (has_evict_policy) + asm_cstrt += ",l"; // --- // finally call inline ASM // --- @@ -858,6 +867,10 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); + if (has_evict_policy) + args.push_back(policies_.at(x->get_eviction_policy())); + + Value *_ret = call(inlineAsm, args); // --- // extract and store return values @@ -2880,6 +2893,13 @@ void generator::visit_function(ir::function* fn) { bbs_[block] = dst_block; } builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); + // create policies + for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ + std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? 
"evict_first" : "evict_last"; + std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;"; + InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false); + policies_[evict] = call(iasm); + } // initialize layouts for(auto x: layouts_->get_all()){ visit_layout(x.second); From bda209002e2bb758b78b74f61b70083a44f0e695 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 23 Apr 2022 13:18:33 -0700 Subject: [PATCH 096/215] [BACKEND][CODEGEN] vectorization bugfix (#502) --- include/triton/codegen/analysis/align.h | 3 + lib/codegen/analysis/align.cc | 41 ++++++++++- lib/codegen/selection/generator.cc | 10 +++ python/test/unit/runtime/test_comm.py | 98 ------------------------- 4 files changed, 51 insertions(+), 101 deletions(-) delete mode 100644 python/test/unit/runtime/test_comm.py diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h index 2393603cb..513868aea 100644 --- a/include/triton/codegen/analysis/align.h +++ b/include/triton/codegen/analysis/align.h @@ -12,6 +12,7 @@ namespace ir { class phi_node; class splat_inst; class cast_inst; + class cmp_inst; class reshape_inst; class broadcast_inst; class binary_operator; @@ -35,6 +36,7 @@ private: std::vector populate_is_constant_reshape(ir::reshape_inst* x); std::vector populate_is_constant_broadcast(ir::broadcast_inst* x); std::vector populate_is_constant_binop(ir::binary_operator* x); + std::vector populate_is_constant_cmp(ir::cmp_inst* x); std::vector populate_is_constant_gep(ir::getelementptr_inst* x); std::vector populate_is_constant_default(ir::value* v); std::vector populate_is_constant(ir::value *v); @@ -65,6 +67,7 @@ public: void run(ir::module &mod); unsigned get(ir::value* v, unsigned ax) const; std::vector contiguous(ir::value* v) const; + std::vector get_cst_info(ir::value* v) const; private: std::map> is_constant_; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index e92d3b6ee..bd68755f1 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -129,6 +129,33 @@ std::vector align::populate_is_constant_broadcast(ir::broadcast return add_to_cache(x, result, is_constant_); } +std::vector align::populate_is_constant_cmp(ir::cmp_inst* x) { + auto x_shapes = get_shapes(x); + std::vector result; + ir::value* lhs_op = x->get_operand(0); + ir::value* rhs_op = x->get_operand(1); + auto lhs = populate_is_constant(lhs_op); + auto rhs = populate_is_constant(rhs_op); + auto lhs_max_contiguous = populate_max_contiguous(lhs_op); + auto rhs_max_contiguous = populate_max_contiguous(rhs_op); + auto lhs_multiple_of = populate_starting_multiple(lhs_op); + auto rhs_multiple_of = populate_starting_multiple(rhs_op); + for(size_t d = 0; d < x_shapes.size(); d++) { + cst_info ax = {1, 0}; + // if lhs (resp. rhs) is a range of M value starting at a multiple of N + // and rhs (resp. 
lhs) is made of M constants that are multiples of N + // then comparisons have M constants + int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]); + if(rhs[d].num_cst % lhs_max_contiguous[d] == 0) + ax = {std::min(min_multiple, lhs_max_contiguous[d]), 0}; + else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0) + ax = {std::min(min_multiple, rhs_max_contiguous[d]), 0}; + result.push_back(ax); + } + return add_to_cache(x, result, is_constant_); +} + + std::vector align::populate_is_constant_binop(ir::binary_operator* x) { auto x_shapes = get_shapes(x); std::vector result; @@ -136,12 +163,15 @@ std::vector align::populate_is_constant_binop(ir::binary_operat ir::value* rhs_op = x->get_operand(1); auto lhs = populate_is_constant(lhs_op); auto rhs = populate_is_constant(rhs_op); - auto max_contiguous = populate_max_contiguous(lhs_op); + auto lhs_max_contiguous = populate_max_contiguous(lhs_op); + auto rhs_max_contiguous = populate_max_contiguous(rhs_op); + auto lhs_multiple_of = populate_starting_multiple(lhs_op); + auto rhs_multiple_of = populate_starting_multiple(rhs_op); for(size_t d = 0; d < x_shapes.size(); d++) { cst_info ax; if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ // todo might not be entirely true - unsigned num_constants = gcd(max_contiguous[d], rhs[d].value); + unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value); ax = {num_constants, 0}; } else @@ -184,6 +214,8 @@ std::vector align::populate_is_constant(ir::value *v) { return populate_is_constant_broadcast(x); if(auto *x = dynamic_cast(v)) return populate_is_constant_binop(x); + if(auto *x = dynamic_cast(v)) + return populate_is_constant_cmp(x); if(auto *x = dynamic_cast(v)) return populate_is_constant_gep(x); return populate_is_constant_default(v); @@ -511,12 +543,15 @@ std::vector align::contiguous(ir::value* v) const { return max_contiguous_.at(v); } +std::vector align::get_cst_info(ir::value* v) const { + return is_constant_.at(v); +} + void align::populate(ir::value *v) { populate_is_constant(v); populate_starting_multiple(v); populate_max_contiguous(v); - } void align::run(ir::module &mod) { diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index c60350060..e4723d86b 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -744,6 +744,11 @@ void generator::visit_load_inst(ir::load_inst* x){ if(op->get_type()->is_block_ty()){ auto ord = ords_.at(op); size_t aln = alignment_->get(op, ord[0]); + if(mx){ + size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; + max_eq = std::max(max_eq, 1); + aln = std::min(aln, max_eq); + } auto layout = layouts_->get(x)->to_scanline(); if(layout){ size_t nts = layout->nts(ord[0]); @@ -912,6 +917,11 @@ void generator::visit_store_inst(ir::store_inst * x){ auto ord = ords_.at(x->get_pointer_operand()); size_t aln = alignment_->get(ptr_op, ord[0]); size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; + if(mx){ + size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; + max_eq = std::max(max_eq, 1); + aln = std::min(aln, max_eq); + } vec = std::min(nts, aln); } auto idxs = idxs_.at(val_op); diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py deleted file mode 100644 index ae3fb69d7..000000000 --- a/python/test/unit/runtime/test_comm.py +++ /dev/null @@ -1,98 +0,0 @@ -import subprocess - -import numpy as np -import pytest -import torch - -import triton -import 
triton.language as tl - - -def get_p2p_matrix(): - try: - stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii") - except subprocess.CalledProcessError: - return pytest.skip("No multi-GPU topology", allow_module_level=True) - - lines = stdout.split("Legend")[0].split('\n')[1:] - matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2]) - if matrix.size <= 1: - return pytest.skip("No multi-GPU topology", allow_module_level=True) - else: - return matrix - - -def get_p2p_devices(): - matrix = get_p2p_matrix() - idx = np.where(matrix == "OK") - return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] - - -def get_non_p2p_devices(): - matrix = get_p2p_matrix() - idx = np.where(matrix == "NS") - return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] - - -p2p_devices = get_p2p_devices() -non_p2p_devices = get_non_p2p_devices() - - -@triton.jit -def _copy(from_ptr, to_ptr, N, **meta): - pid = tl.program_id(0) - offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK']) - values = tl.load(from_ptr + offsets, mask=offsets < N) - tl.store(to_ptr + offsets, values, mask=offsets < N) - - -@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support") -@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to", - [(device_kernel, device_from, device_to, stream_from, stream_to) - for device_kernel in p2p_devices - for device_from in p2p_devices - for device_to in p2p_devices - for stream_from in ['default', 'custom'] - for stream_to in ['default', 'custom'] - ]) -def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to): - if device_to == device_from: - return pytest.skip() - - torch.cuda.set_device(device_kernel) - N = 512 - grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),) - - with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)): - x_from = torch.randn(N, dtype=torch.float32, device=device_from) - with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)): - x_to = torch.empty(N, dtype=torch.float32, device=device_to) - - _copy[grid](x_from, x_to, N, BLOCK=1024) - assert torch.allclose(x_from, x_to.to(device_from)) - - -@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support") -@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to", - [(device_kernel, device_from, device_to, stream_from, stream_to) - for device_kernel in non_p2p_devices - for device_from in non_p2p_devices - for device_to in non_p2p_devices - for stream_from in ['default', 'custom'] - for stream_to in ['default', 'custom'] - ]) -def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to): - if device_to == device_from: - return pytest.skip() - - with pytest.raises(RuntimeError): - torch.cuda.set_device(device_kernel) - N = 512 - grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),) - - with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)): - x_from = torch.randn(N, dtype=torch.float32, device=device_from) - with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)): - x_to = torch.empty(N, dtype=torch.float32, device=device_to) - - _copy[grid](x_from, x_to, N, BLOCK=1024) From 3ca792043f244a5f28eb7a347e7757edf63ae0d5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 24 Apr 2022 13:32:35 -0700 Subject: [PATCH 097/215] [TEST] Added test for vectorization --- 
python/test/unit/language/test_core.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4570a5c61..a7f27eaba 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -937,6 +937,24 @@ def test_load_cache_modifier(cache): assert 'ld.global.ca' in ptx assert 'ld.global.cg' not in ptx +@pytest.mark.parametrize("N", [8, 10, 11, 1024]) +def test_vectorization(N): + src = torch.empty(1024, device='cuda') + dst = torch.empty(1024, device='cuda') + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0]) + ptx = pgm.asm["ptx"] + if N % 4 == 0: + assert "ld.global.v4.b32" in ptx + elif N % 2 == 0: + assert "ld.global.v2.b32" in ptx + else: + assert "ld.global.b32" in ptx + # triton.testing.assert_almost_equal(dst, src[:N]) # --------------- # test store # --------------- From 7d544799a00bea26e451277e13c6d503b41d083a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 25 Apr 2022 09:35:36 -0700 Subject: [PATCH 098/215] [BACKEND] Now disabling L2 eviction policy for sm < 80 --- lib/codegen/selection/generator.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e4723d86b..03533e559 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -784,7 +784,7 @@ void generator::visit_load_inst(ir::load_inst* x){ int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); - bool has_evict_policy = x->get_eviction_policy() != ir::load_inst::NORMAL; + bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; // ----- // create inline asm string // ----- @@ -2904,6 +2904,7 @@ void generator::visit_function(ir::function* fn) { } builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); // create policies + if(tgt_->as_nvidia()->sm() >= 80) for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last"; std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;"; From ae2a1ab225335b054e04f6baaf537184bb337a15 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 25 Apr 2022 21:16:00 -0700 Subject: [PATCH 099/215] [BACKEND] Alignment pass improvements (#503) --- lib/codegen/analysis/align.cc | 33 +++++++++++++++++--------- lib/codegen/selection/generator.cc | 1 + python/test/unit/language/test_core.py | 4 +++- python/triton/code_gen.py | 6 ++--- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index bd68755f1..8dabbaf21 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -142,14 +142,17 @@ std::vector align::populate_is_constant_cmp(ir::cmp_inst* x) { auto rhs_multiple_of = populate_starting_multiple(rhs_op); for(size_t d = 0; d < x_shapes.size(); d++) { cst_info ax = {1, 0}; - // if lhs (resp. rhs) is a range of M value starting at a multiple of N - // and rhs (resp. 
lhs) is made of M constants that are multiples of N - // then comparisons have M constants - int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]); - if(rhs[d].num_cst % lhs_max_contiguous[d] == 0) - ax = {std::min(min_multiple, lhs_max_contiguous[d]), 0}; - else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0) - ax = {std::min(min_multiple, rhs_max_contiguous[d]), 0}; + // Examples: + // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8 + // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4 + // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16 + // + // if LHS is a range of N continuous (or equal) elements that starts at M, + // and RHS is a set of N constants that start at K + // then the result in constant in groups of gcd(M, K) + if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 || + rhs[d].num_cst % lhs[d].num_cst == 0) + ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]); result.push_back(ax); } return add_to_cache(x, result, is_constant_); @@ -170,7 +173,6 @@ std::vector align::populate_is_constant_binop(ir::binary_operat for(size_t d = 0; d < x_shapes.size(); d++) { cst_info ax; if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ - // todo might not be entirely true unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value); ax = {num_constants, 0}; } @@ -433,7 +435,7 @@ std::vector align::populate_starting_multiple_binop(ir::binary_operato if(x->is_int_add_sub()) result[d] = gcd(lhs[d], rhs[d]); if(x->is_int_div()) - result[d] = 1; + result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1; if(x->is_int_rem() && rhs[d] > 1){ result[d] = gcd(lhs[d], rhs[d]); } @@ -503,6 +505,15 @@ std::vector align::populate_starting_multiple_default(ir::value* v) { return add_to_cache(v, {1}, starting_multiple_); } +unsigned get_max_multiple(int val){ + if(val == 0) return 1 << 31; + if(val % 16 == 0) return 16; + if(val % 8 == 0) return 8; + if(val % 4 == 0) return 4; + if(val % 2 == 0) return 2; + return 1; +} + std::vector align::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); @@ -518,7 +529,7 @@ std::vector align::populate_starting_multiple(ir::value *v){ if(auto *x = dynamic_cast(v)) return add_to_cache(x, {std::min(x->get_value(), 128)}, starting_multiple_); if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_); + return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_); if(auto *x = dynamic_cast(v)) return populate_starting_multiple_gep(x); if(auto *x = dynamic_cast(v)) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 03533e559..c6f064ea8 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -785,6 +785,7 @@ void generator::visit_load_inst(ir::load_inst* x){ int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; + has_evict_policy = false; // currently disable until supported in `store` // ----- // create inline asm string // ----- diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a7f27eaba..9a997d661 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -937,13 +937,15 @@ def test_load_cache_modifier(cache): assert 'ld.global.ca' in ptx assert 
'ld.global.cg' not in ptx + @pytest.mark.parametrize("N", [8, 10, 11, 1024]) def test_vectorization(N): src = torch.empty(1024, device='cuda') dst = torch.empty(1024, device='cuda') + @triton.jit def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0]) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 5fd1c1be6..711cc87ac 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -942,9 +942,9 @@ class Kernel: assert _type == triton.language.constexpr, "only constexpr annotations are supported for now" wargs[pos] = _type(wargs[pos]) # check that tensors are on GPU. - for arg in wargs: - if hasattr(arg, 'data_ptr'): - assert arg.is_cuda, "All tensors must be on GPU!" + # for arg in wargs: + # if hasattr(arg, 'data_ptr'): + # assert arg.is_cuda, "All tensors must be on GPU!" # set device (i.e., make sure torch has the context initialized) device = torch.cuda.current_device() torch.cuda.set_device(device) From cd78ce48886ada99bfa5d34a7bf46f2d6ab4c346 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 29 Apr 2022 09:17:54 -0700 Subject: [PATCH 100/215] [FRONTEND] Improved error message when assigning None to non-constexpr --- python/triton/code_gen.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 711cc87ac..52f9a5a60 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -358,6 +358,8 @@ class CodeGenerator(ast.NodeVisitor): # by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value + if value is None: + raise ValueError(f'Cannot assign None to non-constexpr `{name}`. 
Please annotate as `: tl.constexpr`') if not isinstance(value, triton.language.tensor): value = triton.language.core._to_tensor(value, self.builder) self.value_constructor.set_value(name, value) @@ -719,7 +721,7 @@ class CodeGenerator(ast.NodeVisitor): ast.NodeVisitor.generic_visit(self, node) def visit_NoneType(self, node): - return None + return triton.language.constexpr(None) def visit(self, node): if node is not None: From 11a908655d42d2506c2eeb11dbbfdbf24aeaee96 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 29 Apr 2022 14:35:09 -0700 Subject: [PATCH 101/215] [FRONTEND] Fixup --- python/triton/code_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 52f9a5a60..709ada423 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -721,7 +721,7 @@ class CodeGenerator(ast.NodeVisitor): ast.NodeVisitor.generic_visit(self, node) def visit_NoneType(self, node): - return triton.language.constexpr(None) + return None def visit(self, node): if node is not None: From 95feb10ec9538ab225c04b2b4b7b0f6b777a88cf Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 30 Apr 2022 14:25:06 -0700 Subject: [PATCH 102/215] [FRONTEND] fixup (#505) --- python/triton/language/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 046e60e0e..9af259a14 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -27,6 +27,8 @@ def _to_tensor(x, builder): elif isinstance(x, float): return tensor(builder.get_float32(x), float32) elif isinstance(x, constexpr): + if x.value is None: + return None return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x From 7c9bc5a47bcff1ffd2044bbaff0ecf819d67e343 Mon Sep 17 00:00:00 2001 From: Sriram Murali Date: Wed, 4 May 2022 20:03:37 -0700 Subject: [PATCH 103/215] [CODEGEN] Change return type of generator::packed_type to appease build warnings (#507) --- include/triton/codegen/selection/generator.h | 2 +- lib/codegen/selection/generator.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index a4f1d33af..d855d3eca 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -116,7 +116,7 @@ private: private: Type *cvt(ir::type *ty); llvm::Attribute cvt(ir::attribute attr); - llvm::StructType* packed_type(ir::value* i); + void packed_type(ir::value* i); void forward_declare(ir::function* fn); public: diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index c6f064ea8..9abf86df8 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -3326,7 +3326,7 @@ void generator::finalize_phi_node(ir::phi_node *x) { } } -StructType* generator::packed_type(ir::value* i){ +void generator::packed_type(ir::value* i){ Type* dtype = cvt(i->get_type()->get_tile_element_ty()); auto* layout = dynamic_cast(layouts_->get(i)); assert(layout); From d87435e536fe35b150ba37795a6298d55b54eefd Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 5 May 2022 19:53:54 -0700 Subject: [PATCH 104/215] [TUTORIALS] Layer norm tutorial now uses residency control (#510) --- python/tutorials/05-layer-norm.py | 306 +++++++++++++++++------------- 1 file changed, 176 insertions(+), 130 deletions(-) diff --git a/python/tutorials/05-layer-norm.py 
b/python/tutorials/05-layer-norm.py index 1cefc60b9..6581c809e 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -3,11 +3,9 @@ Layer Normalization ==================== """ -import torch - import triton import triton.language as tl - +import torch try: # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it # should not be added to extras_require in setup.py. @@ -16,99 +14,113 @@ try: except ModuleNotFoundError: HAS_APEX = False +# fmt: off -# Forward Pass @triton.jit -def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, - BLOCK_SIZE: tl.constexpr): +def _layer_norm_fwd_fused( + Out, + A, + Weight, + Bias, + Mean, Rstd, + stride, N, eps, + BLOCK_SIZE: tl.constexpr, +): # position of elements processed by this program row = tl.program_id(0) - cols = tl.arange(0, BLOCK_SIZE) - mask = cols < N - # offset data pointers to start at the row of interest - X += row * stride - Y += row * stride - # load data and cast to float32 - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + Out += row * stride + A += row * stride # compute mean - mean = tl.sum(x, axis=0) / N - # compute std - xmean = tl.where(mask, x - mean, 0.) - var = tl.sum(xmean * xmean, axis=0) / N + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, - x_arg.stride(0), N, eps, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - ctx.save_for_backward(x, weight, bias, mean, rstd) + _layer_norm_fwd_fused[(M,)]( + out, + a_arg, + weight, + bias, + mean, rstd, + a_arg.stride(0), N, eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + ctx.save_for_backward( + a, weight, bias, mean, rstd, + ) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.eps = eps - return y + if hasattr(bias, "config"): + assert bias.config.grad_scale_name == weight.config.grad_scale_name + grad_scale_name = bias.config.grad_scale_name + else: + grad_scale_name = None + ctx.grad_scale_gain_bias_name = grad_scale_name + return out @staticmethod - def backward(ctx, dy): - x, w, b, m, v = ctx.saved_tensors + def backward(ctx, dout): + assert dout.is_contiguous() + a, weight, bias, mean, var = ctx.saved_tensors # heuristics for amount of parallel reduction stream for DG/DB - N = w.shape[0] - GROUP_SIZE_M = 64 - if N <= 8192: GROUP_SIZE_M = 96 - if N <= 4096: GROUP_SIZE_M = 128 - if N <= 1024: GROUP_SIZE_M = 256 + N = weight.shape[0] # allocate output - locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') - _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) - _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) - dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) - db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) - dx = torch.empty_like(dy) + da = torch.empty_like(dout) # enqueue kernel using forward pass heuristics # also compute partial sums for DW and DB - x_arg = x.reshape(-1, x.shape[-1]) + x_arg = a.reshape(-1, a.shape[-1]) M, N = x_arg.shape - _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, - x_arg.stride(0), N, ctx.eps, - 
BLOCK_SIZE_N=ctx.BLOCK_SIZE, - GROUP_SIZE_M=GROUP_SIZE_M, - num_warps=ctx.num_warps) - grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device) + dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device) + _layer_norm_bwd_dx_fused[(M,)]( + da, + dout, + a, + weight, + mean, var, + x_arg.stride(0), M, N, + ctx.eps, + BLOCK_SIZE_N=ctx.BLOCK_SIZE, + num_warps=ctx.num_warps, + ) # accumulate partial sums in separate kernel - _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, - BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128) - return dx, None, dw, db, None + grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] + _layer_norm_bwd_dwdb[grid]( + a, dout, + mean, var, + dweight, + dbias, + M, + N, + BLOCK_SIZE_M=32, + BLOCK_SIZE_N=128, + ) + return (da, None, dweight, dbias, None, None, + None, None, None, None, + None, + None, None, None, + None, + None, None, None, + None, None, None, + None, None, None) -layer_norm = LayerNorm.apply - +def layer_norm(a, normalized_shape, weight, bias, eps): + return LayerNorm.apply(a, normalized_shape, weight, bias, eps) def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): + torch.manual_seed(0) # create data x_shape = (M, N) w_shape = (x_shape[-1], ) @@ -224,11 +269,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', - plot_name='layer-norm-backward', - args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} + plot_name='layer-norm', + args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'} ) ) -def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): +def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) @@ -258,4 +303,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c return gbps(ms), gbps(max_ms), gbps(min_ms) +# test_layer_norm(1151, 8192, torch.float16) bench_layer_norm.run(save_path='.', print_data=True) From cd30a99aa2efe6af67967c9d5cbd4342768d21b8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 12 May 2022 12:28:15 -0700 Subject: [PATCH 105/215] [TUTORIALS] fixed formatting --- python/tutorials/05-layer-norm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 6581c809e..a03fa2cf6 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -3,9 +3,11 @@ Layer Normalization ==================== """ +import torch + import triton import triton.language as tl -import torch + try: # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it # should not be added to extras_require in setup.py. 
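
The forward kernel rewritten in the two layer-norm patches above no longer loads a whole row at once; it streams over the feature dimension in BLOCK_SIZE chunks. Per row, the computation it performs can be summarized by the NumPy sketch below (a minimal illustration only; the function name, the default BLOCK_SIZE and the float32 upcasting are assumptions, not code from the tutorial):

import numpy as np

def layer_norm_row_reference(a_row, weight, bias, eps=1e-5, BLOCK_SIZE=128):
    # a_row, weight, bias are assumed to be 1-D arrays of length N
    N = a_row.shape[0]
    # pass 1: accumulate the mean one BLOCK_SIZE chunk at a time
    _mean = 0.0
    for off in range(0, N, BLOCK_SIZE):
        _mean += a_row[off:off + BLOCK_SIZE].astype(np.float32).sum()
    mean = _mean / N
    # pass 2: accumulate the variance the same way
    _var = 0.0
    for off in range(0, N, BLOCK_SIZE):
        a = a_row[off:off + BLOCK_SIZE].astype(np.float32) - mean
        _var += (a * a).sum()
    rstd = 1.0 / np.sqrt(_var / N + eps)
    # pass 3: normalize, then scale and shift
    return (a_row.astype(np.float32) - mean) * rstd * weight + bias

Streaming over N in fixed-size chunks is what removes the earlier hard limit on the feature dimension: the previous kernel raised an error whenever a row did not fit in a single block.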
From c736ba7c3e170b20c203bf7ef4616f931acab2f5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 12 May 2022 12:30:36 -0700 Subject: [PATCH 106/215] [TUTORIALS] Fixed formatting --- python/tutorials/05-layer-norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index a03fa2cf6..802c0aca7 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -129,8 +129,8 @@ def _layer_norm_bwd_dwdb( db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] + mask = (rows[:, None] < M) & (cols[None,:] < N) + offs = rows[:, None] * N + cols[None,:] a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) mean = tl.load(Mean + rows, mask=rows Date: Thu, 12 May 2022 12:41:25 -0700 Subject: [PATCH 107/215] [TUTORIALS] Removed #noformat in layer norm tutorial --- python/tutorials/05-layer-norm.py | 72 ++++++++++++++++--------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 802c0aca7..9880b428f 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -16,14 +16,13 @@ try: except ModuleNotFoundError: HAS_APEX = False -# fmt: off @triton.jit def _layer_norm_fwd_fused( - Out, - A, - Weight, - Bias, + Out, + A, + Weight, + Bias, Mean, Rstd, stride, N, eps, BLOCK_SIZE: tl.constexpr, @@ -37,17 +36,17 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(A + cols, mask=cols Date: Thu, 12 May 2022 13:07:39 -0700 Subject: [PATCH 108/215] [FRONTEND] Handle torch.uint8 args (#513) Co-authored-by: Philippe Tillet --- python/triton/code_gen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 709ada423..35c097017 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -822,6 +822,7 @@ class Kernel: torch.float32: 'f32', torch.float64: 'f64', torch.bool: 'i1', + torch.uint8: 'u8', torch.int8: 'i8', torch.int16: 'i16', torch.int32: 'i32', From d1a22a94e6a7a64865108a66277d5daf5e77f959 Mon Sep 17 00:00:00 2001 From: Mengchi Zhang Date: Fri, 13 May 2022 11:46:12 -0700 Subject: [PATCH 109/215] [FRONTEND] Add empty return value and remove protect to open the access to contained_tys_vec_t (#514) Signed-off-by: Mengchi Zhang --- include/triton/ir/instructions.h | 1 + include/triton/ir/type.h | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index c2d427ae8..9f4e18da8 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -475,6 +475,7 @@ protected: std::string get_eviction_policy_repr() const { if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; if (eviction_ == EVICT_LAST) return ".L2::evict_last"; + return ""; } EVICTION_POLICY eviction_; CACHE_MODIFIER cache_; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 16a81cb5f..2c9d25294 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -21,7 +21,6 @@ class type { public: typedef std::vector block_shapes_t; -protected: typedef std::vector 
contained_tys_vec_t; typedef contained_tys_vec_t::iterator ty_iterator; typedef contained_tys_vec_t::const_iterator const_ty_iterator; From d35617bea13755327dff9e96a4390fd500e2f8ad Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 14 May 2022 15:26:13 -0700 Subject: [PATCH 110/215] [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) --- include/triton/codegen/selection/generator.h | 1 + lib/codegen/selection/generator.cc | 98 +++++++++++--------- python/test/unit/language/test_core.py | 13 ++- 3 files changed, 65 insertions(+), 47 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index d855d3eca..945b9b074 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -177,6 +177,7 @@ public: void visit_sqrt_inst(ir::sqrt_inst*); Value* shfl_sync(Value* acc, int32_t i); void visit_reduce1d_inst(ir::reduce_inst*, std::function, Value*); + void visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral); void visit_reducend_inst(ir::reduce_inst*, std::function, Value*); void visit_reduce_inst(ir::reduce_inst*); void visit_select_inst(ir::select_inst*); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 9abf86df8..cf51a3b4c 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2311,60 +2311,69 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ } /** - * \brief Code Generation for `reduce` (1D case) + * \brief Code Generation for `reduce` (ND case) */ -void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { - std::map partial; +void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ + // ir::value *arg = x->get_operand(0); + analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline(); + std::vector shapes = layout->get_shape(); + std::vector order = layout->get_order(); + unsigned mts = layout->mts(order[0]); + unsigned nts = layout->nts(order[0]); + unsigned col_per_thread = shapes[order[0]] / mts; + auto idxs = idxs_.at(arg); + size_t n_elts = idxs.size(); + // Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); - Value *acc = nullptr; - - // reduce within thread - for(indices_t idx: idxs_.at(arg)){ - Value *val = vals_[arg][idx]; - acc = !acc ? 
val : do_acc(acc, val); - } - // reduce within wrap - for(int i = 16; i > 0; i >>= 1) - acc = do_acc(acc, shfl_sync(acc, i)); - // pointers unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* warp = udiv(thread, i32(32)); Value* lane = urem(thread, i32(32)); - // store warp result in shared memory - add_barrier(); - store(neutral, gep(base, lane)); - add_barrier(); - store(acc, gep(base, warp)); - add_barrier(); + size_t warps_per_inner = std::max(mts/32, 1); + Value* warp_i = udiv(warp, i32(warps_per_inner)); + unsigned row_per_thread = std::max(32/mts, 1); - // reduce across warps - Value *cond = icmp_eq(warp, i32(0)); - Instruction *barrier = add_barrier(); - builder_->SetInsertPoint(barrier->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - Value* ret = load(gep(base, thread)); - for(int i = (num_warps_+1)/2; i > 0; i >>= 1){ - Value *current = shfl_sync(ret, i); - ret = do_acc(ret, current); + for(size_t i = 0; i < n_elts/col_per_thread; i++){ + Value* acc; + // reduce within thread + for(size_t j = 0; j < col_per_thread; j++){ + Value* val = vals_[arg][idxs[i*col_per_thread + j]]; + acc = (j == 0) ? val : do_acc(acc, val); + } + // reduce within warp + for(int k = std::min(mts, 32)/2 ; k > 0; k >>= 1) + acc = do_acc(acc, shfl_sync(acc, k)); + // store warp result in shared memory + Value* ret = acc; + if(mts >= 32){ + add_barrier(); + store(neutral, gep(base, lane)); + add_barrier(); + store(acc, gep(base, warp)); + add_barrier(); + // reduce across warps + Value *cond = icmp_eq(warp, i32(0)); + Instruction *barrier = add_barrier(); + builder_->SetInsertPoint(barrier->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); + Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); + dummy->removeFromParent(); + builder_->SetInsertPoint(term); + ret = load(gep(base, thread)); + for(int k = (mts/32)/2; k > 0; k >>= 1){ + Value *current = shfl_sync(ret, k); + ret = do_acc(ret, current); + } + store(ret, gep(base, thread)); + builder_->SetInsertPoint(barrier->getParent()); + ret = load(gep(base, warp)); + } + vals_[x][idxs_[x][i]] = ret; } - store(ret, gep(base, thread)); - - // store first warp done - builder_->SetInsertPoint(barrier->getParent()); - ret = load(base); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ret; } -/** - * \brief Code Generation for `reduce` (ND case) - */ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { ir::value *arg = x->get_operand(0); Type *ty = cvt(x->get_type()->get_scalar_ty()); @@ -2462,8 +2471,9 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - if(arg->get_type()->get_tile_rank() == 1) - visit_reduce1d_inst(x, do_acc, neutral); + analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); + if(scanline && scanline->get_order()[0] == x->get_axis()) + visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9a997d661..77a870eea 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ 
-676,9 +676,16 @@ def test_reduce1d(dtype_str, shape, device='cuda'): np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -@pytest.mark.parametrize("dtype_str, shape, axis", [ - (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32'] -]) +reduce_configs1 = [ + (dtype, (1, 1024), axis) for dtype in ['float32', 'uint32'] + for axis in [1] +] +reduce_configs2 = [ + ('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] +] + + +@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2) def test_reduce2d(dtype_str, shape, axis, device='cuda'): # triton kernel @triton.jit From abea3dc2c6a70892195d7d0d944e63d478039a92 Mon Sep 17 00:00:00 2001 From: Jiabao Lei <42935526+Karbo123@users.noreply.github.com> Date: Sun, 15 May 2022 07:21:46 +0800 Subject: [PATCH 111/215] [FRONTEND] provide device kwargs && fix fstring error for py<3.8 (#515) Co-authored-by: Philippe Tillet --- python/bench/bench_blocksparse.py | 4 ++-- python/test/unit/language/test_core.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py index d678f49f8..5deb06edb 100644 --- a/python/bench/bench_blocksparse.py +++ b/python/bench/bench_blocksparse.py @@ -40,7 +40,7 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, # create op tflops = lambda ms: num_flops / ms * 1e3 if provider == 'triton': - op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT) + op = triton.ops.blocksparse.matmul(layout, block, op_mode, device="cuda", trans_a=AT, trans_b=BT) # inputs a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b @@ -83,7 +83,7 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50): a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda') if provider == 'triton': a = triton.testing.sparsify_tensor(a, layout, block) - op = triton.ops.blocksparse.softmax(layout, block) + op = triton.ops.blocksparse.softmax(layout, block, device="cuda") gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep) return gbps(mean_ms), gbps(min_ms), gbps(max_ms) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 77a870eea..22c6f99f4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -644,7 +644,7 @@ def test_f16_to_f8_rounding(): ) assert torch.all( torch.logical_not(mismatch) - ), f"{f16_input[mismatch]=} {f16_output[mismatch]=} {abs_error[mismatch]=} {min_error[mismatch]=}" + ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}" # --------------- From 205a493b10a5112ec1fccdbe9d59fe9f172e027d Mon Sep 17 00:00:00 2001 From: daadaada Date: Sun, 22 May 2022 00:45:54 +0800 Subject: [PATCH 112/215] [FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520) Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas --- python/test/unit/language/test_core.py | 32 ++++++++++++++++++++++++++ python/triton/language/core.py | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py 
b/python/test/unit/language/test_core.py index 22c6f99f4..952922f6b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -514,9 +514,41 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) +def test_atomic_cas(): + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1,), device='cuda', dtype=torch.int32) + change_value[(1,)](Lock) + + assert(Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1,), device='cuda', dtype=torch.int32) + data = torch.zeros((128,), device='cuda', dtype=torch.float32) + ref = torch.full((128,), 64.0) + serialized_add[(64,)](data, Lock) + triton.testing.assert_almost_equal(data, ref) + + # --------------- # test cast # --------------- + + @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ (dtype_x, dtype_z, False) for dtype_x in dtypes diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 9af259a14..7ef63abba 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -806,7 +806,7 @@ def _add_atomic_docstr(name): @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): cmp = _to_tensor(cmp, _builder) - val = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) return semantic.atomic_cas(pointer, cmp, val, _builder) From 80f6a2698b56997f25d4be0a10191eb196fcb021 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Mon, 23 May 2022 13:40:08 -0700 Subject: [PATCH 113/215] [FRONTEND] Ensure version_key is called at most once (#519) Co-authored-by: hauntsaninja <> --- python/triton/code_gen.py | 52 +++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 35c097017..82ace0105 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -11,6 +11,7 @@ import subprocess import sys import tempfile import textwrap +import threading import time import warnings from typing import Dict, Set, Tuple, Union @@ -1058,27 +1059,40 @@ class Autotuner: return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) -@functools.lru_cache() +_version_key_lock = threading.Lock() +_version_key = None + + def version_key(): - import pkgutil - contents = [] - # frontend - with open(triton.code_gen.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # backend - with open(triton._C.libtriton.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # language - language_path = os.path.join(*triton.__path__, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + global _version_key + + if _version_key is not None: + return _version_key + + with _version_key_lock: + if _version_key is not None: + return _version_key + + import pkgutil + contents = [] + # frontend + with open(triton.code_gen.__file__, "rb") as f: contents += [hashlib.md5(f.read()).hexdigest()] - # ptxas version - try: - ptxas_version = 
hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() - except Exception: - ptxas_version = '' - return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + # backend + with open(triton._C.libtriton.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # language + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version + try: + ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() + except Exception: + ptxas_version = '' + _version_key = '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + return _version_key class DependenciesFinder(ast.NodeVisitor): From d5eaa8dfa0fc8a941a09da8d00b7d4c48b8ebd87 Mon Sep 17 00:00:00 2001 From: daadaada Date: Tue, 24 May 2022 23:56:36 +0800 Subject: [PATCH 114/215] Making the generated Triton IR deterministic & a script to compare cached assembly (#522) --- include/triton/codegen/analysis/liveness.h | 12 ++-- include/triton/tools/graph.h | 15 +++-- python/triton/tools/compare_asm.py | 76 ++++++++++++++++++++++ 3 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 python/triton/tools/compare_asm.py diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index a95d62a06..12232b654 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -1,12 +1,14 @@ #ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H #define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H -#include -#include -#include #include "triton/codegen/analysis/layout.h" #include "triton/tools/graph.h" +#include "llvm/ADT/MapVector.h" + +#include +#include + namespace triton{ namespace ir{ @@ -42,14 +44,14 @@ struct segment { class liveness { private: - typedef std::map intervals_map_t; + typedef llvm::MapVector intervals_map_t; public: // constructor liveness(layouts *l): layouts_(l){ } // accessors const intervals_map_t& get() const { return intervals_; } - segment get(shared_layout* v) const { return intervals_.at(v); } + segment get(shared_layout* v) const { return intervals_.lookup(v); } // run void run(ir::module &mod); diff --git a/include/triton/tools/graph.h b/include/triton/tools/graph.h index c2ba8d854..69afd5bb3 100644 --- a/include/triton/tools/graph.h +++ b/include/triton/tools/graph.h @@ -3,8 +3,9 @@ #ifndef _TRITON_TOOLS_THREAD_GRAPH_H_ #define _TRITON_TOOLS_THREAD_GRAPH_H_ +#include "llvm/ADT/SetVector.h" + #include -#include #include #include @@ -13,21 +14,21 @@ namespace tools{ template class graph { - typedef std::map> edges_t; + typedef std::map> edges_t; public: typedef std::map> cmap_t; typedef std::map nmap_t; private: - void connected_components_impl(node_t x, std::set &nodes, + void connected_components_impl(node_t x, llvm::SetVector &nodes, nmap_t* nmap, cmap_t* cmap, int id) const { if(nmap) (*nmap)[x] = id; if(cmap) (*cmap)[id].push_back(x); - if(nodes.find(x) != nodes.end()) { - nodes.erase(x); + if (nodes.count(x)) { + nodes.remove(x); for(const node_t &y: edges_.at(x)) connected_components_impl(y, nodes, nmap, cmap, id); } @@ -39,7 +40,7 @@ public: cmap->clear(); if(nmap) nmap->clear(); - std::set nodes = nodes_; + llvm::SetVector nodes = nodes_; unsigned id = 0; while(!nodes.empty()){ connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); @@ -59,7 +60,7 
@@ public: } private: - std::set nodes_; + llvm::SetVector nodes_; edges_t edges_; }; diff --git a/python/triton/tools/compare_asm.py b/python/triton/tools/compare_asm.py new file mode 100644 index 000000000..e612022bd --- /dev/null +++ b/python/triton/tools/compare_asm.py @@ -0,0 +1,76 @@ +''' +Compare cached triton kernels in 2 directories. + +example: +python compare_asm.py --dir0=triton-works/ --dir1=triton-fails/ --asm=ttir \ + --diff-out0=diff-works.ll --diff-out1=diff-fails.ll +''' +import argparse +import os +import pickle + +parser = argparse.ArgumentParser(description="unpickle") +parser.add_argument('--dir0', dest='dir0', required=True, + help="Triton cache dir 0") +parser.add_argument('--dir1', dest='dir1', required=True, + help="Triton cache dir 1") +parser.add_argument('--asm', dest='asm', + choices=['ttir', 'llir', 'ptx', 'cubin'], required=True) +parser.add_argument('--early-stop', dest='early_stop', action='store_true', + help="Stop after first diff") +parser.set_defaults(early_stop=True) +parser.add_argument('--diff-out0', dest='diff_out0', required=True, + help="output file path for kernels in dir0") +parser.add_argument('--diff-out1', dest='diff_out1', required=True, + help="output file path for kernels in dir1") +args = parser.parse_args() +dir0 = args.dir0 +dir1 = args.dir1 +asm = args.asm + +dir0_files = {} +dir1_files = {} +for root, _, files in os.walk(dir0): + for file in files: + if not file.endswith('.lock'): + path = os.path.join(root, file) + with open(path, 'rb') as f: + loaded_file = pickle.load(f) + bin = loaded_file['binary'] + key = loaded_file['key'] + info = key.split('-')[-3:] # num_warps, num_stages, signature + dict_key = bin.name + '-'.join(info) + dir0_files[dict_key] = bin.asm + +for root, _, files in os.walk(dir1): + for file in files: + if not file.endswith('.lock'): + path = os.path.join(root, file) + with open(path, 'rb') as f: + loaded_file = pickle.load(f) + bin = loaded_file['binary'] + key = loaded_file['key'] + info = key.split('-')[-3:] # num_warps, num_stages, signature + dict_key = bin.name + '-'.join(info) + dir1_files[dict_key] = bin.asm + +diff_keys = [] +for key in dir0_files: + asm0 = dir0_files[key] + if key not in dir1_files: + continue + asm1 = dir1_files[key] + if asm0[asm] != asm1[asm]: + diff_keys.append(key) + +if args.early_stops: + diff_keys = diff_keys[:1] +if diff_keys: + with open(args.diff_out0, 'w') as f0, open(args.diff_out1, 'w') as f1: + for key in diff_keys: + f0.write(f'{asm} mismatch at {key}') + f0.write(dir0_files[key][asm]) + f0.write('\n') + f1.write(f'{asm} mismatch at {key}') + f1.write(dir1_files[key][asm]) + f1.write('\n') From 96bff90471343ed01dd94effe91ecdaaa7a3f36a Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 24 May 2022 12:08:49 -0700 Subject: [PATCH 115/215] [FRONTEND] faster jit function launch (#523) With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us. 
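
A quick way to see the effect is to time a no-op kernel launch in a loop, for example (an illustrative sketch only; the kernel, grid and iteration count are arbitrary and not part of this patch):

import time
import torch
import triton
import triton.language as tl

@triton.jit
def _noop(X):
    pass

x = torch.empty(1, device="cuda")
_noop[(1,)](x)                      # first call compiles the kernel
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(1000):
    _noop[(1,)](x)                  # subsequent calls exercise the launch path
torch.cuda.synchronize()
print("us per launch:", (time.perf_counter() - t0) / 1000 * 1e6)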
--- python/triton/code_gen.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 82ace0105..619e3109e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -23,12 +23,17 @@ import triton import triton._C.libtriton.triton as _triton from .tools.disasm import extract +try: + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream +except ImportError: + get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream + def current_cuda_stream(device_idx=0): # Torch's torch.cuda.current_stream() is slow. We provide this # function to give the user an opportunity to monkey-patch their # own faster current stream lookup. - return torch.cuda.current_stream().cuda_stream + return get_cuda_stream(device_idx) def mangle_ty(ty): @@ -910,6 +915,7 @@ class Kernel: def __init__(self, fn): self.fn = fn + self.cache_key = {} def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] @@ -951,12 +957,11 @@ class Kernel: # assert arg.is_cuda, "All tensors must be on GPU!" # set device (i.e., make sure torch has the context initialized) device = torch.cuda.current_device() - torch.cuda.set_device(device) - # query compute capability - cc = torch.cuda.get_device_capability(device) - cc = str(cc[0]) + '-' + str(cc[1]) - cache_key = self.fn.cache_key + cc - # query current stream + if device not in self.cache_key: + cc = torch.cuda.get_device_capability(device) + cc = str(cc[0]) + '-' + str(cc[1]) + self.cache_key[device] = self.fn.cache_key + cc + cache_key = self.cache_key[device] stream = current_cuda_stream(device) return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, From 011bc83c1b3240016954f38dca7b11809617605b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 24 May 2022 13:20:10 -0700 Subject: [PATCH 116/215] [FRONTEND] For loops now promote initial value (#524) --- python/triton/code_gen.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 619e3109e..d2c834f40 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -584,13 +584,17 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) return + # create nodes st_target = ast.Name(id=node.target.id, ctx=ast.Store()) ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0) arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0] arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1) + # init node init_node = ast.Assign(targets=[st_target], value=arg_0) + + # step node pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) @@ -610,7 +614,17 @@ class CodeGenerator(ast.NodeVisitor): cond = build_cond() return self.builder.cond_br(cond.handle, loop_bb, next_bb) + # init loop induction variable self.visit(init_node) + # promote it to right type + init_val = self.value_constructor.get_value(node.target.id) + promote = lambda a, b: triton.language.semantic.computation_type_impl(a, b, False) + start_ty = 
triton.language.core._to_tensor(iter_args[0], self.builder).type + stop_ty = triton.language.core._to_tensor(iter_args[1], self.builder).type if len(iter_args) > 1 else None + ty = promote(start_ty, stop_ty) if len(iter_args) > 1 else start_ty + casted = triton.language.semantic.cast(init_val, ty, self.builder) + self.value_constructor.set_value(node.target.id, casted) + # create cond cond = build_cond() self.builder.cond_br(cond.handle, loop_bb, next_bb) self.builder.set_insert_block(loop_bb) From 43fec2adca046d3ec7e438f56f9c32165a53b2b7 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Wed, 25 May 2022 18:26:18 -0400 Subject: [PATCH 117/215] [FRONTEND] Add binding for create_int_to_ptr (#526) --- python/src/triton.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/python/src/triton.cc b/python/src/triton.cc index a1cf7e54e..ac2bedebf 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -840,6 +840,7 @@ void init_triton_ir(py::module &&m) { .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) .def("create_downcast", &ir::builder::create_downcast, ret::reference) + .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference) // phi .def("create_phi", &ir::builder::create_phi, ret::reference) // Binary instructions From 0e2883020a8797565221610db8664a1a7a62b561 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 25 May 2022 20:01:19 -0700 Subject: [PATCH 118/215] [BACKEND] Fixed typo in alignment analysis (#528) --- lib/codegen/analysis/align.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 8dabbaf21..37b609228 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -507,6 +507,9 @@ std::vector align::populate_starting_multiple_default(ir::value* v) { unsigned get_max_multiple(int val){ if(val == 0) return 1 << 31; + if(val % 128 == 0) return 128; + if(val % 64 == 0) return 64; + if(val % 32 == 0) return 32; if(val % 16 == 0) return 16; if(val % 8 == 0) return 8; if(val % 4 == 0) return 4; @@ -527,7 +530,7 @@ std::vector align::populate_starting_multiple(ir::value *v){ if(auto *x = dynamic_cast(v)) return populate_starting_multiple_binop(x); if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {std::min(x->get_value(), 128)}, starting_multiple_); + return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_); if(auto *x = dynamic_cast(v)) return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_); if(auto *x = dynamic_cast(v)) From c82a2066841208a048335e019517c2530944e34b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 26 May 2022 17:41:09 -0700 Subject: [PATCH 119/215] [FRONTEND] Better dot error message (#531) --- python/test/unit/language/test_core.py | 2 +- python/triton/language/core.py | 20 ++++++++++++++++++++ python/triton/language/semantic.py | 4 ++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 952922f6b..50bfb9d1c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -910,7 +910,7 @@ def test_arange(start, device='cuda'): def test_masked_load_shared_memory(dtype, device='cuda'): M = 32 N = 32 - K = 8 + K = 16 in1 = torch.rand((M, K), dtype=dtype, device=device) in2 = torch.rand((K, N), dtype=dtype, device=device) diff --git 
a/python/triton/language/core.py b/python/triton/language/core.py index 7ef63abba..f81645a36 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -342,6 +342,26 @@ class constexpr: def __bool__(self): return bool(self.value) + def __ge__(self, other): + other = other.value if isinstance(other, constexpr) else other + return self.value >= other + + def __gt__(self, other): + other = other.value if isinstance(other, constexpr) else other + return self.value > other + + def __le__(self, other): + other = other.value if isinstance(other, constexpr) else other + return self.value <= other + + def __lt__(self, other): + other = other.value if isinstance(other, constexpr) else other + return self.value < other + + def __eq__(self, other): + other = other.value if isinstance(other, constexpr) else other + return self.value == other + def __call__(self, *args, **kwds): return self.value(*args, **kwds) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 10d20fbb3..2af25cbb2 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -905,6 +905,10 @@ def dot(lhs: tl.tensor, allow_tf32: bool, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() + assert len(lhs.shape) == 2 and len(rhs.shape) == 2 + assert lhs.shape[-1] == rhs.shape[0] + assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\ + "small blocks not supported!" if lhs.type.scalar.is_int(): _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 From 37037bb3beecd2f0011bf1541fcb26d18646447e Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Fri, 27 May 2022 16:51:05 -0400 Subject: [PATCH 120/215] [FRONTEND] Default cache dir to /tmp/triton_$USER (#527) --- python/triton/code_gen.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index d2c834f40..6fbfe9dde 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1150,6 +1150,11 @@ class DependenciesFinder(ast.NodeVisitor): self.ret = hashlib.md5(self.ret).hexdigest() +def default_cache_dir(): + import getpass + return f'/tmp/triton_{getpass.getuser()}' + + class JITFunction: cache_hook = None @@ -1235,7 +1240,7 @@ class JITFunction: hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') + cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) if cache_dir: os.makedirs(cache_dir, exist_ok=True) From 3e7500dfe628abbda134198fb19d9f9ff0158686 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 31 May 2022 17:14:44 -0700 Subject: [PATCH 121/215] [BACKEND] Various bug fixes; making reductions faster (#533) --- include/triton/codegen/analysis/layout.h | 2 +- lib/codegen/analysis/align.cc | 4 +- lib/codegen/analysis/layout.cc | 8 +- lib/codegen/pass.cc | 1 + lib/codegen/selection/generator.cc | 170 +++++++++++++------ lib/codegen/transform/coalesce.cc | 19 +++ python/setup.py | 2 +- python/test/unit/language/test_core.py | 23 +++ python/triton/code_gen.py | 2 +- python/triton/language/core.py | 2 + python/triton/language/semantic.py | 2 +- python/tutorials/03-matrix-multiplication.py | 5 +- 12 files changed, 174 insertions(+), 66 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 28dfad18d..050ac6956 100644 --- a/include/triton/codegen/analysis/layout.h +++ 
b/include/triton/codegen/analysis/layout.h @@ -224,7 +224,7 @@ struct scanline_layout: public distributed_layout { int nts(size_t k) { return nts_.at(k); } int contig_per_thread(size_t k) { return nts_.at(k); } - int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);} + int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);} public: // micro tile size. The size of a tile held by a thread block. std::vector mts_; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 37b609228..1c48a4c05 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -319,8 +319,8 @@ std::vector align::populate_max_contiguous_binop(ir::binary_operator* } if(x->is_int_add_sub()){ unsigned lvalue = 1, rvalue = 1; - lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]); - rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]); + lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst); + rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst); value = std::max(lvalue, rvalue); } result.push_back(value); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index cec512fec..86473dc54 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -209,14 +209,15 @@ mma_layout::mma_layout(size_t num_warps, rep_ = {2*pack_size_0, 2*pack_size_1, 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; contig_per_thread_ = {1, 1}; + order_ = {0, 1}; } else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 contig_per_thread_ = {1, 2}; + order_ = {1, 0}; // rep_ = {2, 2, 1}; } - order_ = {0, 1}; /* warps per tile */ wpt_ = {1, 1, 1}; @@ -616,8 +617,9 @@ void layouts::run(ir::module &mod) { unsigned axis = red->get_axis(); // shape auto shapes = arg->get_type()->get_block_shapes(); - scanline_layout *layout = get(arg)->to_scanline(); - shapes[axis] = layout->mts(axis); + distributed_layout* layout = dynamic_cast(get(arg)); + shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis); + // create layout layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); tmp_[red] = id; diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index e2cd6d228..4ba423d20 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -88,6 +88,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); + // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index cf51a3b4c..5397ceefe 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -88,6 +88,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define f16_ty builder_->getHalfTy() #define bf16_ty builder_->getBFloatTy() #define f32_ty builder_->getFloatTy() +#define i1_ty builder_->getInt1Ty() #define i8_ty builder_->getInt8Ty() #define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() @@ -736,6 +737,9 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { * \brief Code Generation for a (synchronous) `load` */ void generator::visit_load_inst(ir::load_inst* x){ + BasicBlock *current = builder_->GetInsertBlock(); + Module *module = current->getModule(); + Value *tid = 
tgt_->get_local_id(module, *builder_, 0); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); @@ -775,6 +779,9 @@ void generator::visit_load_inst(ir::load_inst* x){ in_off = 0; } Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); + // if(!op->get_type()->is_block_ty()){ + // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0))); + // } Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; size_t nbits = dtsize*8; // pack sub-words (< 32/64bits) into words @@ -878,6 +885,18 @@ void generator::visit_load_inst(ir::load_inst* x){ Value *_ret = call(inlineAsm, args); + // if(!op->get_type()->is_block_ty()){ + // Value* cond = icmp_eq(tid, i32(0)); + // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3)); + // Instruction* bar = add_barrier(); + // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false); + // builder_->SetInsertPoint(term); + // store(_ret, shptr); + // builder_->SetInsertPoint(bar->getParent()); + // _ret = load(shptr); + // add_barrier(); + // } + // --- // extract and store return values // --- @@ -2033,12 +2052,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // create mma & unpack result, m, n, k are offsets in mat auto call_mma = [&](unsigned m, unsigned n, unsigned k) { - unsigned cols_per_thread = num_rep_m * 2; + unsigned cols_per_thread = num_rep_n * 2; std::vector idx = { - (m + 0) + (n*2 + 0)*cols_per_thread, - (m + 0) + (n*2 + 1)*cols_per_thread, - (m + 1) + (n*2 + 0)*cols_per_thread, - (m + 1) + (n*2 + 1)*cols_per_thread + (m + 0)*cols_per_thread + (n*2 + 0), + (m + 0)*cols_per_thread + (n*2 + 1), + (m + 1)*cols_per_thread + (n*2 + 0), + (m + 1)*cols_per_thread + (n*2 + 1) }; Value *nc = call(mma_ty, mma_fn, {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], @@ -2316,62 +2335,93 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ // ir::value *arg = x->get_operand(0); - analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline(); + analysis::distributed_layout* layout = dynamic_cast(layouts_->get(arg)); std::vector shapes = layout->get_shape(); - std::vector order = layout->get_order(); - unsigned mts = layout->mts(order[0]); - unsigned nts = layout->nts(order[0]); - unsigned col_per_thread = shapes[order[0]] / mts; - auto idxs = idxs_.at(arg); - size_t n_elts = idxs.size(); + + Type* sca_ty = cvt(arg->get_type()->get_scalar_ty()); + size_t n_bits = sca_ty->getPrimitiveSizeInBits(); + + std::string n_bits_str = std::to_string(n_bits); + std::string cst = (n_bits == 64) ? 
"l" : "r"; + + FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false); + InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true); + FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false); + InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true); + + + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value* warp = udiv(thread, i32(32)); + Value* lane = urem(thread, i32(32)); + + unsigned shuffle_width = 0; + unsigned warps_per_inner = 0; + auto arg_vals = vals_.at(arg); + std::vector arg_idxs = idxs_.at(arg); + size_t n_elts = arg_idxs.size(); + unsigned col_per_thread; + Value* warp_i; + Value* warp_j; + if(analysis::scanline_layout* scanline = layout->to_scanline()){ + std::vector order = layout->get_order(); + unsigned mts = scanline->mts(order[0]); + shuffle_width = std::min(mts, 32); + warps_per_inner = std::max(mts/32, 1); + col_per_thread = shapes[order[0]] / mts; + warp_i = udiv(warp, i32(warps_per_inner)); + warp_j = urem(warp, i32(warps_per_inner)); + } + else if(layout->to_mma()){ + shuffle_width = 4; + warps_per_inner = layout->to_mma()->wpt(1); + col_per_thread = 16; + warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id; + warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; + } + + // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); // Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value* warp = udiv(thread, i32(32)); - Value* lane = urem(thread, i32(32)); - size_t warps_per_inner = std::max(mts/32, 1); - Value* warp_i = udiv(warp, i32(warps_per_inner)); - unsigned row_per_thread = std::max(32/mts, 1); - + // preds + Value* is_lane0 = icmp_eq(lane, i32(0)); + Value* is_warp0 = icmp_eq(warp, i32(0)); + Value* is_thread0 = icmp_eq(thread, i32(0)); + Value* lane_j = urem(lane, i32(shuffle_width)); + Value* first_lane_in_col = icmp_eq(lane_j, i32(0)); + add_barrier(); + // compute partial sum for each warp, and store to shared memory for(size_t i = 0; i < n_elts/col_per_thread; i++){ Value* acc; // reduce within thread for(size_t j = 0; j < col_per_thread; j++){ - Value* val = vals_[arg][idxs[i*col_per_thread + j]]; + Value* val = arg_vals[arg_idxs[i*col_per_thread + j]]; + // acc = (j == 0) ? val : do_acc(acc, val); acc = (j == 0) ? 
val : do_acc(acc, val); } // reduce within warp - for(int k = std::min(mts, 32)/2 ; k > 0; k >>= 1) + for(int k = shuffle_width/2 ; k > 0; k >>= 1) acc = do_acc(acc, shfl_sync(acc, k)); - // store warp result in shared memory - Value* ret = acc; - if(mts >= 32){ - add_barrier(); - store(neutral, gep(base, lane)); - add_barrier(); - store(acc, gep(base, warp)); - add_barrier(); - // reduce across warps - Value *cond = icmp_eq(warp, i32(0)); - Instruction *barrier = add_barrier(); - builder_->SetInsertPoint(barrier->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - ret = load(gep(base, thread)); - for(int k = (mts/32)/2; k > 0; k >>= 1){ - Value *current = shfl_sync(ret, k); - ret = do_acc(ret, current); - } - store(ret, gep(base, thread)); - builder_->SetInsertPoint(barrier->getParent()); - ret = load(gep(base, warp)); - } - vals_[x][idxs_[x][i]] = ret; + // store partial result to shared memory + auto x_idxs = idxs_[x][i]; + Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; + Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); + call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc}); } + add_barrier(); + // at this point, partial accumulator synchronized in shared memory + // Just need to reduce `warp_per_inner` numbers in shared memory + for(size_t i = 0; i < n_elts/col_per_thread; i++){ + auto x_idxs = idxs_[x][i]; + Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; + Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner))); + Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); + for(int k = warps_per_inner/2; k > 0; k >>= 1) + acc = do_acc(acc, shfl_sync(acc, k)); + vals_[x][idxs_[x][i]] = acc; + } + // add_barrier(); } void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { @@ -2471,8 +2521,12 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); + int cc = tgt_->as_nvidia()->sm(); analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); - if(scanline && scanline->get_order()[0] == x->get_axis()) + analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma(); + bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis()); + bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1); + if(is_coalesced_scanline || is_a100_mma) visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); @@ -2665,12 +2719,12 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned in_vec = 1; ir::value *arg = cts->get_operand(0); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); - analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); + analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(arg)); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) - in_vec = in_layout->nts(in_order[0]); + in_vec = in_layout->contig_per_thread(in_order[0]); int out_vec = swizzle_->get_vec(out_layout); int min_vec = std::min(out_vec, in_vec); int s = std::max(out_vec / in_vec, 1); @@ -2678,8 +2732,11 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { int 
per_phase = swizzle_->get_per_phase(out_layout); int max_phase = swizzle_->get_max_phase(out_layout); // - int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); - int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); + int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); + int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]); + + int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; + int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); int n_shared_0 = std::max(in_vec / out_vec, 1); BasicBlock* CurrBB = builder_->GetInsertBlock(); @@ -2700,8 +2757,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); - int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); + int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; + int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; int off = (off_1*shapes[in_order[0]] + off_0); std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ @@ -3026,8 +3083,7 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { else{ /* warp offset */ Value *warp_0 = urem(warp, i32(layout->wpt(0))); - Value *warp_12 = udiv(warp, i32(layout->wpt(0))); - Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); + Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1))); Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); Value *off_lane_m = urem(lane, _16); @@ -3152,7 +3208,9 @@ void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ + // i->print(std::cout); visit_value(i); + // std::cout << "done" << std::endl; } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index ae8ce034d..d969139f1 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -52,6 +52,7 @@ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) //} void coalesce::run(ir::module &mod) { + std::set invalidated; ir::builder& builder = mod.get_builder(); // add layout conversion instructions for(ir::function *fn: mod.get_function_list()) @@ -61,12 +62,29 @@ void coalesce::run(ir::module &mod) { if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(1)) if(op->get_type()->is_block_ty()) + if(op->get_type()->get_tile_rank() == 2) + if(invalidated.find(layout_->get(op)) == invalidated.end()) if(layout_->get(op)->to_mma()){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); builder.set_insert_point(i); builder.insert(new_op); i->replace_uses_of_with(op, new_op); } + // coalesce before copy_to_shared + // It's dirty, but the backend is being rewritten from scratch. 
:) + if(dynamic_cast(i)) + if(ir::value* op = i->get_operand(0)) + if(op->get_type()->is_block_ty()) + if(op->get_type()->get_tile_rank() == 2) + if(invalidated.find(layout_->get(op)) == invalidated.end()) + if(layout_->get(op)->to_mma()){ + ir::instruction* new_op = ir::cvt_layout_inst::create(op); + builder.set_insert_point(i); + builder.insert(new_op); + op->replace_all_uses_with(new_op); + new_op->replace_uses_of_with(new_op, op); + invalidated.insert(layout_->get(op)); + } // uncoalesce after load if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) @@ -120,6 +138,7 @@ void coalesce::run(ir::module &mod) { } if(in_contig.size() <= 1 || out_contig==in_contig) continue; + std::cout << "3!!" << std::endl; builder.set_insert_point_after(val_inst); auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); x->replace_uses_of_with(val_inst, new_val); diff --git a/python/setup.py b/python/setup.py index 9179baa5b..6a04a4e42 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - # self.debug = True + self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 50bfb9d1c..71df6d73b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'): rs = RandomState(17) x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) + x[:] = 1 # numpy result z_ref = np.sum(x).astype(getattr(np, dtype_str)) # triton result @@ -1132,3 +1133,25 @@ def test_constexpr_shape(): x_tri = to_triton(np.empty((256, ), dtype=np.int32)) kernel[(1,)](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + +# ------------- +# test if +# ------------- + + +def test_if(): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret): + pid = tl.program_id(0) + cond = tl.load(Cond) + if pid % 2: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device='cuda') + x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') + x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') + ret = torch.empty(1, dtype=torch.float32, device='cuda') + kernel[(1,)](cond, x_true, x_false, ret) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 6fbfe9dde..4f0c75f8d 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -63,7 +63,7 @@ def mangle_ty(ty): def mangle_fn(name, arg_tys, constants): # doesn't mangle ret type, which must be a function of arg tys mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) - key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) + key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x) mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) mangled_constants = mangled_constants.replace('.', '_d_') mangled_constants = mangled_constants.replace("'", '_sq_') diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f81645a36..f0cc02e66 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -32,6 +32,8 @@ def _to_tensor(x, builder): return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x + elif 
x is None: + return None assert False, f'cannot convert {x} to tensor' diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 2af25cbb2..e1c8e6028 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -559,7 +559,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type - if src_ty.is_block(): + if src_ty.is_block() and not dst_ty.is_block(): dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) if src_ty == dst_ty: return input diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index f773a3787..912833c52 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -252,6 +252,7 @@ def matmul_kernel( # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul` @triton.jit def leaky_relu(x): + x = x + 1 return tl.where(x >= 0, x, 0.01 * x) @@ -296,7 +297,7 @@ def matmul(a, b, activation=None): torch.manual_seed(0) a = torch.randn((512, 512), device='cuda', dtype=torch.float16) b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -triton_output = matmul(a, b, activation=None) +triton_output = matmul(a, b, activation=leaky_relu) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") @@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output): else: print("❌ Triton and Torch differ") +print(matmul_kernel.cache_key) +exit() # %% # Benchmark # -------------- From efa04cac1ff746701dd1087aca4dae418473413b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 2 Jun 2022 16:57:37 -0700 Subject: [PATCH 122/215] [FRONTEND] A couple of bugfixes (#534) --- python/triton/code_gen.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 4f0c75f8d..b64c7eb86 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -63,7 +63,7 @@ def mangle_ty(ty): def mangle_fn(name, arg_tys, constants): # doesn't mangle ret type, which must be a function of arg tys mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) - key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x) + key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) mangled_constants = mangled_constants.replace('.', '_d_') mangled_constants = mangled_constants.replace("'", '_sq_') @@ -971,6 +971,10 @@ class Kernel: # assert arg.is_cuda, "All tensors must be on GPU!" # set device (i.e., make sure torch has the context initialized) device = torch.cuda.current_device() + # torch creates new thread for backward pass that may have uninitlialized context + # no way to know if this function should or shouldn't initialize the cuda context + # so we're being conservative here + torch.cuda.set_device(device) if device not in self.cache_key: cc = torch.cuda.get_device_capability(device) cc = str(cc[0]) + '-' + str(cc[1]) From a60374a5979f4a68025a5f1fb17d3d1c79332317 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 3 Jun 2022 11:36:06 -0700 Subject: [PATCH 123/215] Revert "[BACKEND] Various bug fixes; making reductions faster (#533)". This is a more stable commit that produce bitwise identical code to earlier versions. 
Using commits after this one may lead to slightly different numerics --- include/triton/codegen/analysis/layout.h | 2 +- lib/codegen/analysis/align.cc | 4 +- lib/codegen/analysis/layout.cc | 8 +- lib/codegen/pass.cc | 1 - lib/codegen/selection/generator.cc | 170 ++++++------------- lib/codegen/transform/coalesce.cc | 19 --- python/setup.py | 2 +- python/test/unit/language/test_core.py | 23 --- python/triton/language/core.py | 2 - python/triton/language/semantic.py | 2 +- python/tutorials/03-matrix-multiplication.py | 5 +- 11 files changed, 65 insertions(+), 173 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 050ac6956..28dfad18d 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -224,7 +224,7 @@ struct scanline_layout: public distributed_layout { int nts(size_t k) { return nts_.at(k); } int contig_per_thread(size_t k) { return nts_.at(k); } - int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);} + int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);} public: // micro tile size. The size of a tile held by a thread block. std::vector mts_; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 1c48a4c05..37b609228 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -319,8 +319,8 @@ std::vector align::populate_max_contiguous_binop(ir::binary_operator* } if(x->is_int_add_sub()){ unsigned lvalue = 1, rvalue = 1; - lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst); - rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst); + lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]); + rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]); value = std::max(lvalue, rvalue); } result.push_back(value); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 86473dc54..cec512fec 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -209,15 +209,14 @@ mma_layout::mma_layout(size_t num_warps, rep_ = {2*pack_size_0, 2*pack_size_1, 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; contig_per_thread_ = {1, 1}; - order_ = {0, 1}; } else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 contig_per_thread_ = {1, 2}; - order_ = {1, 0}; // rep_ = {2, 2, 1}; } + order_ = {0, 1}; /* warps per tile */ wpt_ = {1, 1, 1}; @@ -617,9 +616,8 @@ void layouts::run(ir::module &mod) { unsigned axis = red->get_axis(); // shape auto shapes = arg->get_type()->get_block_shapes(); - distributed_layout* layout = dynamic_cast(get(arg)); - shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis); - + scanline_layout *layout = get(arg)->to_scanline(); + shapes[axis] = layout->mts(axis); // create layout layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); tmp_[red] = id; diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 4ba423d20..e2cd6d228 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -88,7 +88,6 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); - // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 5397ceefe..cf51a3b4c 100644 
--- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -88,7 +88,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define f16_ty builder_->getHalfTy() #define bf16_ty builder_->getBFloatTy() #define f32_ty builder_->getFloatTy() -#define i1_ty builder_->getInt1Ty() #define i8_ty builder_->getInt8Ty() #define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() @@ -737,9 +736,6 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { * \brief Code Generation for a (synchronous) `load` */ void generator::visit_load_inst(ir::load_inst* x){ - BasicBlock *current = builder_->GetInsertBlock(); - Module *module = current->getModule(); - Value *tid = tgt_->get_local_id(module, *builder_, 0); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); @@ -779,9 +775,6 @@ void generator::visit_load_inst(ir::load_inst* x){ in_off = 0; } Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); - // if(!op->get_type()->is_block_ty()){ - // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0))); - // } Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; size_t nbits = dtsize*8; // pack sub-words (< 32/64bits) into words @@ -885,18 +878,6 @@ void generator::visit_load_inst(ir::load_inst* x){ Value *_ret = call(inlineAsm, args); - // if(!op->get_type()->is_block_ty()){ - // Value* cond = icmp_eq(tid, i32(0)); - // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3)); - // Instruction* bar = add_barrier(); - // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false); - // builder_->SetInsertPoint(term); - // store(_ret, shptr); - // builder_->SetInsertPoint(bar->getParent()); - // _ret = load(shptr); - // add_barrier(); - // } - // --- // extract and store return values // --- @@ -2052,12 +2033,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // create mma & unpack result, m, n, k are offsets in mat auto call_mma = [&](unsigned m, unsigned n, unsigned k) { - unsigned cols_per_thread = num_rep_n * 2; + unsigned cols_per_thread = num_rep_m * 2; std::vector idx = { - (m + 0)*cols_per_thread + (n*2 + 0), - (m + 0)*cols_per_thread + (n*2 + 1), - (m + 1)*cols_per_thread + (n*2 + 0), - (m + 1)*cols_per_thread + (n*2 + 1) + (m + 0) + (n*2 + 0)*cols_per_thread, + (m + 0) + (n*2 + 1)*cols_per_thread, + (m + 1) + (n*2 + 0)*cols_per_thread, + (m + 1) + (n*2 + 1)*cols_per_thread }; Value *nc = call(mma_ty, mma_fn, {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], @@ -2335,93 +2316,62 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ // ir::value *arg = x->get_operand(0); - analysis::distributed_layout* layout = dynamic_cast(layouts_->get(arg)); + analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline(); std::vector shapes = layout->get_shape(); - - Type* sca_ty = cvt(arg->get_type()->get_scalar_ty()); - size_t n_bits = sca_ty->getPrimitiveSizeInBits(); - - std::string n_bits_str = std::to_string(n_bits); - std::string cst = (n_bits == 64) ? 
"l" : "r"; - - FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false); - InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true); - FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false); - InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true); - - - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value* warp = udiv(thread, i32(32)); - Value* lane = urem(thread, i32(32)); - - unsigned shuffle_width = 0; - unsigned warps_per_inner = 0; - auto arg_vals = vals_.at(arg); - std::vector arg_idxs = idxs_.at(arg); - size_t n_elts = arg_idxs.size(); - unsigned col_per_thread; - Value* warp_i; - Value* warp_j; - if(analysis::scanline_layout* scanline = layout->to_scanline()){ - std::vector order = layout->get_order(); - unsigned mts = scanline->mts(order[0]); - shuffle_width = std::min(mts, 32); - warps_per_inner = std::max(mts/32, 1); - col_per_thread = shapes[order[0]] / mts; - warp_i = udiv(warp, i32(warps_per_inner)); - warp_j = urem(warp, i32(warps_per_inner)); - } - else if(layout->to_mma()){ - shuffle_width = 4; - warps_per_inner = layout->to_mma()->wpt(1); - col_per_thread = 16; - warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id; - warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; - } - - // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); + std::vector order = layout->get_order(); + unsigned mts = layout->mts(order[0]); + unsigned nts = layout->nts(order[0]); + unsigned col_per_thread = shapes[order[0]] / mts; + auto idxs = idxs_.at(arg); + size_t n_elts = idxs.size(); // Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); - // preds - Value* is_lane0 = icmp_eq(lane, i32(0)); - Value* is_warp0 = icmp_eq(warp, i32(0)); - Value* is_thread0 = icmp_eq(thread, i32(0)); - Value* lane_j = urem(lane, i32(shuffle_width)); - Value* first_lane_in_col = icmp_eq(lane_j, i32(0)); - add_barrier(); - // compute partial sum for each warp, and store to shared memory + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value* warp = udiv(thread, i32(32)); + Value* lane = urem(thread, i32(32)); + size_t warps_per_inner = std::max(mts/32, 1); + Value* warp_i = udiv(warp, i32(warps_per_inner)); + unsigned row_per_thread = std::max(32/mts, 1); + for(size_t i = 0; i < n_elts/col_per_thread; i++){ Value* acc; // reduce within thread for(size_t j = 0; j < col_per_thread; j++){ - Value* val = arg_vals[arg_idxs[i*col_per_thread + j]]; - // acc = (j == 0) ? val : do_acc(acc, val); + Value* val = vals_[arg][idxs[i*col_per_thread + j]]; acc = (j == 0) ? val : do_acc(acc, val); } // reduce within warp - for(int k = shuffle_width/2 ; k > 0; k >>= 1) + for(int k = std::min(mts, 32)/2 ; k > 0; k >>= 1) acc = do_acc(acc, shfl_sync(acc, k)); - // store partial result to shared memory - auto x_idxs = idxs_[x][i]; - Value* x_idx = x_idxs.empty() ? 
builder_->getInt32(0) : x_idxs[0]; - Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); - call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc}); + // store warp result in shared memory + Value* ret = acc; + if(mts >= 32){ + add_barrier(); + store(neutral, gep(base, lane)); + add_barrier(); + store(acc, gep(base, warp)); + add_barrier(); + // reduce across warps + Value *cond = icmp_eq(warp, i32(0)); + Instruction *barrier = add_barrier(); + builder_->SetInsertPoint(barrier->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); + Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); + dummy->removeFromParent(); + builder_->SetInsertPoint(term); + ret = load(gep(base, thread)); + for(int k = (mts/32)/2; k > 0; k >>= 1){ + Value *current = shfl_sync(ret, k); + ret = do_acc(ret, current); + } + store(ret, gep(base, thread)); + builder_->SetInsertPoint(barrier->getParent()); + ret = load(gep(base, warp)); + } + vals_[x][idxs_[x][i]] = ret; } - add_barrier(); - // at this point, partial accumulator synchronized in shared memory - // Just need to reduce `warp_per_inner` numbers in shared memory - for(size_t i = 0; i < n_elts/col_per_thread; i++){ - auto x_idxs = idxs_[x][i]; - Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; - Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner))); - Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); - for(int k = warps_per_inner/2; k > 0; k >>= 1) - acc = do_acc(acc, shfl_sync(acc, k)); - vals_[x][idxs_[x][i]] = acc; - } - // add_barrier(); } void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { @@ -2521,12 +2471,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - int cc = tgt_->as_nvidia()->sm(); analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); - analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma(); - bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis()); - bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1); - if(is_coalesced_scanline || is_a100_mma) + if(scanline && scanline->get_order()[0] == x->get_axis()) visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); @@ -2719,12 +2665,12 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned in_vec = 1; ir::value *arg = cts->get_operand(0); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); - analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(arg)); + analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) - in_vec = in_layout->contig_per_thread(in_order[0]); + in_vec = in_layout->nts(in_order[0]); int out_vec = swizzle_->get_vec(out_layout); int min_vec = std::min(out_vec, in_vec); int s = std::max(out_vec / in_vec, 1); @@ -2732,11 +2678,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { int per_phase = swizzle_->get_per_phase(out_layout); int max_phase = swizzle_->get_max_phase(out_layout); // - int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); - int mts_1 = in_layout->shape_per_cta(in_order[1]) / 
in_layout->contig_per_thread(in_order[1]); - - int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; - int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); + int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); + int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); int n_shared_0 = std::max(in_vec / out_vec, 1); BasicBlock* CurrBB = builder_->GetInsertBlock(); @@ -2757,8 +2700,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; - int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; + int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); + int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); int off = (off_1*shapes[in_order[0]] + off_0); std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ @@ -3083,7 +3026,8 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { else{ /* warp offset */ Value *warp_0 = urem(warp, i32(layout->wpt(0))); - Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1))); + Value *warp_12 = udiv(warp, i32(layout->wpt(0))); + Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); Value *off_lane_m = urem(lane, _16); @@ -3208,9 +3152,7 @@ void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ - // i->print(std::cout); visit_value(i); - // std::cout << "done" << std::endl; } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index d969139f1..ae8ce034d 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -52,7 +52,6 @@ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) //} void coalesce::run(ir::module &mod) { - std::set invalidated; ir::builder& builder = mod.get_builder(); // add layout conversion instructions for(ir::function *fn: mod.get_function_list()) @@ -62,29 +61,12 @@ void coalesce::run(ir::module &mod) { if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(1)) if(op->get_type()->is_block_ty()) - if(op->get_type()->get_tile_rank() == 2) - if(invalidated.find(layout_->get(op)) == invalidated.end()) if(layout_->get(op)->to_mma()){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); builder.set_insert_point(i); builder.insert(new_op); i->replace_uses_of_with(op, new_op); } - // coalesce before copy_to_shared - // It's dirty, but the backend is being rewritten from scratch. 
:) - if(dynamic_cast(i)) - if(ir::value* op = i->get_operand(0)) - if(op->get_type()->is_block_ty()) - if(op->get_type()->get_tile_rank() == 2) - if(invalidated.find(layout_->get(op)) == invalidated.end()) - if(layout_->get(op)->to_mma()){ - ir::instruction* new_op = ir::cvt_layout_inst::create(op); - builder.set_insert_point(i); - builder.insert(new_op); - op->replace_all_uses_with(new_op); - new_op->replace_uses_of_with(new_op, op); - invalidated.insert(layout_->get(op)); - } // uncoalesce after load if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) @@ -138,7 +120,6 @@ void coalesce::run(ir::module &mod) { } if(in_contig.size() <= 1 || out_contig==in_contig) continue; - std::cout << "3!!" << std::endl; builder.set_insert_point_after(val_inst); auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); x->replace_uses_of_with(val_inst, new_val); diff --git a/python/setup.py b/python/setup.py index 6a04a4e42..9179baa5b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - self.debug = True + # self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 71df6d73b..50bfb9d1c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -698,7 +698,6 @@ def test_reduce1d(dtype_str, shape, device='cuda'): rs = RandomState(17) x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - x[:] = 1 # numpy result z_ref = np.sum(x).astype(getattr(np, dtype_str)) # triton result @@ -1133,25 +1132,3 @@ def test_constexpr_shape(): x_tri = to_triton(np.empty((256, ), dtype=np.int32)) kernel[(1,)](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) - -# ------------- -# test if -# ------------- - - -def test_if(): - - @triton.jit - def kernel(Cond, XTrue, XFalse, Ret): - pid = tl.program_id(0) - cond = tl.load(Cond) - if pid % 2: - tl.store(Ret, tl.load(XTrue)) - else: - tl.store(Ret, tl.load(XFalse)) - - cond = torch.ones(1, dtype=torch.int32, device='cuda') - x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') - x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') - ret = torch.empty(1, dtype=torch.float32, device='cuda') - kernel[(1,)](cond, x_true, x_false, ret) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f0cc02e66..f81645a36 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -32,8 +32,6 @@ def _to_tensor(x, builder): return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x - elif x is None: - return None assert False, f'cannot convert {x} to tensor' diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index e1c8e6028..2af25cbb2 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -559,7 +559,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type - if src_ty.is_block() and not dst_ty.is_block(): + if src_ty.is_block(): dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) if src_ty == dst_ty: return input diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 912833c52..f773a3787 100644 --- 
a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -252,7 +252,6 @@ def matmul_kernel( # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul` @triton.jit def leaky_relu(x): - x = x + 1 return tl.where(x >= 0, x, 0.01 * x) @@ -297,7 +296,7 @@ def matmul(a, b, activation=None): torch.manual_seed(0) a = torch.randn((512, 512), device='cuda', dtype=torch.float16) b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -triton_output = matmul(a, b, activation=leaky_relu) +triton_output = matmul(a, b, activation=None) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") @@ -306,8 +305,6 @@ if triton.testing.allclose(triton_output, torch_output): else: print("❌ Triton and Torch differ") -print(matmul_kernel.cache_key) -exit() # %% # Benchmark # -------------- From 8876e5320658a83998213ee54ddf70b82a10e363 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 3 Jun 2022 11:38:52 -0700 Subject: [PATCH 124/215] [BACKEND] Restored reduction bugfixes --- include/triton/codegen/analysis/layout.h | 2 +- lib/codegen/analysis/align.cc | 4 +- lib/codegen/analysis/layout.cc | 8 +- lib/codegen/pass.cc | 1 + lib/codegen/selection/generator.cc | 170 +++++++++++++------ lib/codegen/transform/coalesce.cc | 19 +++ python/setup.py | 2 +- python/test/unit/language/test_core.py | 23 +++ python/triton/language/core.py | 2 + python/triton/language/semantic.py | 2 +- python/tutorials/03-matrix-multiplication.py | 5 +- 11 files changed, 173 insertions(+), 65 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 28dfad18d..050ac6956 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -224,7 +224,7 @@ struct scanline_layout: public distributed_layout { int nts(size_t k) { return nts_.at(k); } int contig_per_thread(size_t k) { return nts_.at(k); } - int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);} + int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);} public: // micro tile size. The size of a tile held by a thread block. 
std::vector mts_; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 37b609228..1c48a4c05 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -319,8 +319,8 @@ std::vector align::populate_max_contiguous_binop(ir::binary_operator* } if(x->is_int_add_sub()){ unsigned lvalue = 1, rvalue = 1; - lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]); - rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]); + lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst); + rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst); value = std::max(lvalue, rvalue); } result.push_back(value); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index cec512fec..86473dc54 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -209,14 +209,15 @@ mma_layout::mma_layout(size_t num_warps, rep_ = {2*pack_size_0, 2*pack_size_1, 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; contig_per_thread_ = {1, 1}; + order_ = {0, 1}; } else{ // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 contig_per_thread_ = {1, 2}; + order_ = {1, 0}; // rep_ = {2, 2, 1}; } - order_ = {0, 1}; /* warps per tile */ wpt_ = {1, 1, 1}; @@ -616,8 +617,9 @@ void layouts::run(ir::module &mod) { unsigned axis = red->get_axis(); // shape auto shapes = arg->get_type()->get_block_shapes(); - scanline_layout *layout = get(arg)->to_scanline(); - shapes[axis] = layout->mts(axis); + distributed_layout* layout = dynamic_cast(get(arg)); + shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis); + // create layout layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); tmp_[red] = id; diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index e2cd6d228..4ba423d20 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -88,6 +88,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); + // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index cf51a3b4c..5397ceefe 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -88,6 +88,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define f16_ty builder_->getHalfTy() #define bf16_ty builder_->getBFloatTy() #define f32_ty builder_->getFloatTy() +#define i1_ty builder_->getInt1Ty() #define i8_ty builder_->getInt8Ty() #define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() @@ -736,6 +737,9 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { * \brief Code Generation for a (synchronous) `load` */ void generator::visit_load_inst(ir::load_inst* x){ + BasicBlock *current = builder_->GetInsertBlock(); + Module *module = current->getModule(); + Value *tid = tgt_->get_local_id(module, *builder_, 0); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); @@ -775,6 +779,9 @@ void generator::visit_load_inst(ir::load_inst* x){ in_off = 0; } Value *pred = mx ? 
vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); + // if(!op->get_type()->is_block_ty()){ + // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0))); + // } Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; size_t nbits = dtsize*8; // pack sub-words (< 32/64bits) into words @@ -878,6 +885,18 @@ void generator::visit_load_inst(ir::load_inst* x){ Value *_ret = call(inlineAsm, args); + // if(!op->get_type()->is_block_ty()){ + // Value* cond = icmp_eq(tid, i32(0)); + // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3)); + // Instruction* bar = add_barrier(); + // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false); + // builder_->SetInsertPoint(term); + // store(_ret, shptr); + // builder_->SetInsertPoint(bar->getParent()); + // _ret = load(shptr); + // add_barrier(); + // } + // --- // extract and store return values // --- @@ -2033,12 +2052,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // create mma & unpack result, m, n, k are offsets in mat auto call_mma = [&](unsigned m, unsigned n, unsigned k) { - unsigned cols_per_thread = num_rep_m * 2; + unsigned cols_per_thread = num_rep_n * 2; std::vector idx = { - (m + 0) + (n*2 + 0)*cols_per_thread, - (m + 0) + (n*2 + 1)*cols_per_thread, - (m + 1) + (n*2 + 0)*cols_per_thread, - (m + 1) + (n*2 + 1)*cols_per_thread + (m + 0)*cols_per_thread + (n*2 + 0), + (m + 0)*cols_per_thread + (n*2 + 1), + (m + 1)*cols_per_thread + (n*2 + 0), + (m + 1)*cols_per_thread + (n*2 + 1) }; Value *nc = call(mma_ty, mma_fn, {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], @@ -2316,62 +2335,93 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ // ir::value *arg = x->get_operand(0); - analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline(); + analysis::distributed_layout* layout = dynamic_cast(layouts_->get(arg)); std::vector shapes = layout->get_shape(); - std::vector order = layout->get_order(); - unsigned mts = layout->mts(order[0]); - unsigned nts = layout->nts(order[0]); - unsigned col_per_thread = shapes[order[0]] / mts; - auto idxs = idxs_.at(arg); - size_t n_elts = idxs.size(); + + Type* sca_ty = cvt(arg->get_type()->get_scalar_ty()); + size_t n_bits = sca_ty->getPrimitiveSizeInBits(); + + std::string n_bits_str = std::to_string(n_bits); + std::string cst = (n_bits == 64) ? 
"l" : "r"; + + FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false); + InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true); + FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false); + InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true); + + + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value* warp = udiv(thread, i32(32)); + Value* lane = urem(thread, i32(32)); + + unsigned shuffle_width = 0; + unsigned warps_per_inner = 0; + auto arg_vals = vals_.at(arg); + std::vector arg_idxs = idxs_.at(arg); + size_t n_elts = arg_idxs.size(); + unsigned col_per_thread; + Value* warp_i; + Value* warp_j; + if(analysis::scanline_layout* scanline = layout->to_scanline()){ + std::vector order = layout->get_order(); + unsigned mts = scanline->mts(order[0]); + shuffle_width = std::min(mts, 32); + warps_per_inner = std::max(mts/32, 1); + col_per_thread = shapes[order[0]] / mts; + warp_i = udiv(warp, i32(warps_per_inner)); + warp_j = urem(warp, i32(warps_per_inner)); + } + else if(layout->to_mma()){ + shuffle_width = 4; + warps_per_inner = layout->to_mma()->wpt(1); + col_per_thread = 16; + warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id; + warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; + } + + // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); // Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value* warp = udiv(thread, i32(32)); - Value* lane = urem(thread, i32(32)); - size_t warps_per_inner = std::max(mts/32, 1); - Value* warp_i = udiv(warp, i32(warps_per_inner)); - unsigned row_per_thread = std::max(32/mts, 1); - + // preds + Value* is_lane0 = icmp_eq(lane, i32(0)); + Value* is_warp0 = icmp_eq(warp, i32(0)); + Value* is_thread0 = icmp_eq(thread, i32(0)); + Value* lane_j = urem(lane, i32(shuffle_width)); + Value* first_lane_in_col = icmp_eq(lane_j, i32(0)); + add_barrier(); + // compute partial sum for each warp, and store to shared memory for(size_t i = 0; i < n_elts/col_per_thread; i++){ Value* acc; // reduce within thread for(size_t j = 0; j < col_per_thread; j++){ - Value* val = vals_[arg][idxs[i*col_per_thread + j]]; + Value* val = arg_vals[arg_idxs[i*col_per_thread + j]]; + // acc = (j == 0) ? val : do_acc(acc, val); acc = (j == 0) ? 
val : do_acc(acc, val); } // reduce within warp - for(int k = std::min(mts, 32)/2 ; k > 0; k >>= 1) + for(int k = shuffle_width/2 ; k > 0; k >>= 1) acc = do_acc(acc, shfl_sync(acc, k)); - // store warp result in shared memory - Value* ret = acc; - if(mts >= 32){ - add_barrier(); - store(neutral, gep(base, lane)); - add_barrier(); - store(acc, gep(base, warp)); - add_barrier(); - // reduce across warps - Value *cond = icmp_eq(warp, i32(0)); - Instruction *barrier = add_barrier(); - builder_->SetInsertPoint(barrier->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - ret = load(gep(base, thread)); - for(int k = (mts/32)/2; k > 0; k >>= 1){ - Value *current = shfl_sync(ret, k); - ret = do_acc(ret, current); - } - store(ret, gep(base, thread)); - builder_->SetInsertPoint(barrier->getParent()); - ret = load(gep(base, warp)); - } - vals_[x][idxs_[x][i]] = ret; + // store partial result to shared memory + auto x_idxs = idxs_[x][i]; + Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; + Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); + call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc}); } + add_barrier(); + // at this point, partial accumulator synchronized in shared memory + // Just need to reduce `warp_per_inner` numbers in shared memory + for(size_t i = 0; i < n_elts/col_per_thread; i++){ + auto x_idxs = idxs_[x][i]; + Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; + Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner))); + Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); + for(int k = warps_per_inner/2; k > 0; k >>= 1) + acc = do_acc(acc, shfl_sync(acc, k)); + vals_[x][idxs_[x][i]] = acc; + } + // add_barrier(); } void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { @@ -2471,8 +2521,12 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); + int cc = tgt_->as_nvidia()->sm(); analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); - if(scanline && scanline->get_order()[0] == x->get_axis()) + analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma(); + bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis()); + bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1); + if(is_coalesced_scanline || is_a100_mma) visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); @@ -2665,12 +2719,12 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned in_vec = 1; ir::value *arg = cts->get_operand(0); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); - analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); + analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(arg)); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) - in_vec = in_layout->nts(in_order[0]); + in_vec = in_layout->contig_per_thread(in_order[0]); int out_vec = swizzle_->get_vec(out_layout); int min_vec = std::min(out_vec, in_vec); int s = std::max(out_vec / in_vec, 1); @@ -2678,8 +2732,11 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { int 
per_phase = swizzle_->get_per_phase(out_layout); int max_phase = swizzle_->get_max_phase(out_layout); // - int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); - int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); + int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); + int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]); + + int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; + int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); int n_shared_0 = std::max(in_vec / out_vec, 1); BasicBlock* CurrBB = builder_->GetInsertBlock(); @@ -2700,8 +2757,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); - int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); + int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; + int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; int off = (off_1*shapes[in_order[0]] + off_0); std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ @@ -3026,8 +3083,7 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { else{ /* warp offset */ Value *warp_0 = urem(warp, i32(layout->wpt(0))); - Value *warp_12 = udiv(warp, i32(layout->wpt(0))); - Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); + Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1))); Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); Value *off_lane_m = urem(lane, _16); @@ -3152,7 +3208,9 @@ void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ + // i->print(std::cout); visit_value(i); + // std::cout << "done" << std::endl; } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index ae8ce034d..d969139f1 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -52,6 +52,7 @@ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) //} void coalesce::run(ir::module &mod) { + std::set invalidated; ir::builder& builder = mod.get_builder(); // add layout conversion instructions for(ir::function *fn: mod.get_function_list()) @@ -61,12 +62,29 @@ void coalesce::run(ir::module &mod) { if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(1)) if(op->get_type()->is_block_ty()) + if(op->get_type()->get_tile_rank() == 2) + if(invalidated.find(layout_->get(op)) == invalidated.end()) if(layout_->get(op)->to_mma()){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); builder.set_insert_point(i); builder.insert(new_op); i->replace_uses_of_with(op, new_op); } + // coalesce before copy_to_shared + // It's dirty, but the backend is being rewritten from scratch. 
:) + if(dynamic_cast(i)) + if(ir::value* op = i->get_operand(0)) + if(op->get_type()->is_block_ty()) + if(op->get_type()->get_tile_rank() == 2) + if(invalidated.find(layout_->get(op)) == invalidated.end()) + if(layout_->get(op)->to_mma()){ + ir::instruction* new_op = ir::cvt_layout_inst::create(op); + builder.set_insert_point(i); + builder.insert(new_op); + op->replace_all_uses_with(new_op); + new_op->replace_uses_of_with(new_op, op); + invalidated.insert(layout_->get(op)); + } // uncoalesce after load if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) @@ -120,6 +138,7 @@ void coalesce::run(ir::module &mod) { } if(in_contig.size() <= 1 || out_contig==in_contig) continue; + std::cout << "3!!" << std::endl; builder.set_insert_point_after(val_inst); auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); x->replace_uses_of_with(val_inst, new_val); diff --git a/python/setup.py b/python/setup.py index 9179baa5b..6a04a4e42 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,7 @@ class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - # self.debug = True + self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories build_suffix = 'debug' if self.debug else 'release' diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 50bfb9d1c..71df6d73b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'): rs = RandomState(17) x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) + x[:] = 1 # numpy result z_ref = np.sum(x).astype(getattr(np, dtype_str)) # triton result @@ -1132,3 +1133,25 @@ def test_constexpr_shape(): x_tri = to_triton(np.empty((256, ), dtype=np.int32)) kernel[(1,)](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + +# ------------- +# test if +# ------------- + + +def test_if(): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret): + pid = tl.program_id(0) + cond = tl.load(Cond) + if pid % 2: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device='cuda') + x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') + x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') + ret = torch.empty(1, dtype=torch.float32, device='cuda') + kernel[(1,)](cond, x_true, x_false, ret) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f81645a36..f0cc02e66 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -32,6 +32,8 @@ def _to_tensor(x, builder): return _to_tensor(x.value, builder) elif isinstance(x, tensor): return x + elif x is None: + return None assert False, f'cannot convert {x} to tensor' diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 2af25cbb2..e1c8e6028 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -559,7 +559,7 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type - if src_ty.is_block(): + if src_ty.is_block() and not dst_ty.is_block(): dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) if src_ty == dst_ty: return input diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index f773a3787..912833c52 100644 --- 
a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -252,6 +252,7 @@ def matmul_kernel( # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul` @triton.jit def leaky_relu(x): + x = x + 1 return tl.where(x >= 0, x, 0.01 * x) @@ -296,7 +297,7 @@ def matmul(a, b, activation=None): torch.manual_seed(0) a = torch.randn((512, 512), device='cuda', dtype=torch.float16) b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -triton_output = matmul(a, b, activation=None) +triton_output = matmul(a, b, activation=leaky_relu) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") @@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output): else: print("❌ Triton and Torch differ") +print(matmul_kernel.cache_key) +exit() # %% # Benchmark # -------------- From 801c8a4c924f58bf15e9ad8d7270f00e4158478a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 5 Jun 2022 12:32:07 -0700 Subject: [PATCH 125/215] [TUTORIALS] Fixed typo --- python/tutorials/03-matrix-multiplication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 912833c52..2dfa98a42 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -297,7 +297,7 @@ def matmul(a, b, activation=None): torch.manual_seed(0) a = torch.randn((512, 512), device='cuda', dtype=torch.float16) b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -triton_output = matmul(a, b, activation=leaky_relu) +triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") From 751e325d2edbfce723faa2e1bbf254f6e6ef216f Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 5 Jun 2022 13:32:35 -0700 Subject: [PATCH 126/215] [TUTORIALS] Fixed typo --- python/tutorials/03-matrix-multiplication.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 2dfa98a42..39bf8c46a 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -306,8 +306,6 @@ if triton.testing.allclose(triton_output, torch_output): else: print("❌ Triton and Torch differ") -print(matmul_kernel.cache_key) -exit() # %% # Benchmark # -------------- From f13cbaab9fc81fa55a799f9cc76b0dab680f80d7 Mon Sep 17 00:00:00 2001 From: TC <93944281+tomconerlyanth@users.noreply.github.com> Date: Mon, 6 Jun 2022 14:37:08 -0400 Subject: [PATCH 127/215] [FRONTEND] assert that num_warps is a power of 2 (#539) --- python/test/unit/language/test_core.py | 14 ++++++++++++++ python/triton/code_gen.py | 1 + 2 files changed, 15 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 71df6d73b..b2b2cdeb1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1155,3 +1155,17 @@ def test_if(): x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') ret = torch.empty(1, dtype=torch.float32, device='cuda') kernel[(1,)](cond, x_true, x_false, ret) + + +def test_num_warps_pow2(): + dst = torch.empty(128, device='cuda') + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1,)](dst=dst, num_warps=3) 
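    # num_warps=3 trips the power-of-two assert added to Kernel.__call__ below;
    # the remaining launches use 1, 2 and 4, which are accepted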
+ _kernel[(1,)](dst=dst, num_warps=1) + _kernel[(1,)](dst=dst, num_warps=2) + _kernel[(1,)](dst=dst, num_warps=4) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b64c7eb86..01cd1b5ed 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -954,6 +954,7 @@ class Kernel: return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"{num_warps=} must be a power of 2." # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} wargs = list(wargs) From 2cdc6d35c4182074ca0c8ce2dc387b6e9c0ca7a6 Mon Sep 17 00:00:00 2001 From: Mengchi Zhang Date: Mon, 6 Jun 2022 12:48:23 -0700 Subject: [PATCH 128/215] [FRONTEND] Give col_per_thread an initial value to make the compiler happy (#535) Signed-off-by: Mengchi Zhang --- lib/codegen/selection/generator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 5397ceefe..a04949dc5 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2359,7 +2359,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function arg_idxs = idxs_.at(arg); size_t n_elts = arg_idxs.size(); - unsigned col_per_thread; + unsigned col_per_thread = 0; Value* warp_i; Value* warp_j; if(analysis::scanline_layout* scanline = layout->to_scanline()){ From 38573d1261945602d4f84d665e95740e2c6e58cd Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 7 Jun 2022 18:37:12 -0700 Subject: [PATCH 129/215] [FRONTEND] Return allocated registers and spilled registers for users (#541) --- python/src/triton.cc | 17 ++++++++++------- python/triton/code_gen.py | 12 +++++++----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index ac2bedebf..7f9e7e752 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -418,7 +418,7 @@ typedef std::map asm_map_t; // --------------------------------------- // CUDA -std::tuple cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ +std::tuple cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ // load assembly std::string assembly; if(asm_map.find("cubin") != asm_map.end()) @@ -430,24 +430,27 @@ std::tuple cu_load_binary(const std::string& name, asm_map_t CUmodule mod; drv::dispatch::cuModuleLoadData(&mod, assembly.c_str()); drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str()); + // get allocated registers and spilled registers from the function + int n_regs = 0; + int n_spills = 0; + drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun); + drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun); + n_spills /= 4; // set dynamic shared memory if necessary int shared_optin; drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev); if(n_shared_bytes > 49152 && shared_optin > 49152){ drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED); int shared_total, shared_static; - int n_spills, n_reg; drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, 
dev); drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun); - drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun); - drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun); drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static); } - return std::make_tuple((uint64_t)mod, (uint64_t)fun); + return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills); } // ROCM -std::tuple hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ +std::tuple hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ py::bytes _assembly = asm_map["hsaco"]; std::string assembly = py::cast(_assembly); // HSA-CO -> hipModule @@ -456,7 +459,7 @@ std::tuple hip_load_binary(const std::string& name, asm_map_ hipFunction_t fun; drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str()); // record asm - return std::make_tuple((uint64_t)mod, (uint64_t)fun); + return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0); } // --------------------------------------- diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 01cd1b5ed..60feb1740 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -769,16 +769,18 @@ class Binary: class LoadedBinary: def __init__(self, device: int, bin: Binary): - module, kernel = _triton.code_gen.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, - device) + module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend, + bin.name, + bin.asm, + bin.shared_mem, + device) self.bin = bin self.asm = bin.asm self.sass = '' self.module = module self.kernel = kernel + self.n_regs = n_regs + self.n_spills = n_spills self.device = device self.shared_mem = bin.shared_mem From 7094657aa993aa9c29d80a44446ff20956ec52e7 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Mon, 13 Jun 2022 15:52:37 -0700 Subject: [PATCH 130/215] [FRONTEND] fix bool conversion of floating types (#545) --- python/test/unit/language/test_core.py | 10 +++++++--- python/triton/language/semantic.py | 9 +++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b2b2cdeb1..c76cbbd95 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -557,6 +557,7 @@ def test_atomic_cas(): ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), ('float32', 'int32', True), + ('float32', 'int1', False), ] + [ (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] ] + [ @@ -565,6 +566,8 @@ def test_atomic_cas(): def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. 
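    # the float -> int1 case added below uses 0.5 so that the new != 0 lowering
    # in semantic.py turns a nonzero float into True instead of truncating it to 0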
x0 = 43 if dtype_x in int_dtypes else 43.5 + if dtype_x in float_dtypes and dtype_z == 'int1': + x0 = 0.5 if dtype_x.startswith('bfloat'): x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device) else: @@ -578,11 +581,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): z = x.to(Z.dtype.element_ty, bitcast=BITCAST) tl.store(Z, z) + dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' # triton result if dtype_z.startswith('bfloat'): z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device) else: - z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device) + z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device) kernel[(1, )](x_tri, z_tri, BITCAST=bitcast) # torch result if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): @@ -591,9 +595,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): assert z_tri == z_ref else: if bitcast: - z_ref = x.view(getattr(np, dtype_z)) + z_ref = x.view(getattr(np, dtype_z_np)) else: - z_ref = x.astype(getattr(np, dtype_z)) + z_ref = x.astype(getattr(np, dtype_z_np)) assert to_numpy(z_tri) == z_ref diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index e1c8e6028..753944285 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -565,7 +565,6 @@ def cast(input: tl.tensor, return input src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - # bf16 <=> (not fp32) if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): @@ -601,9 +600,7 @@ def cast(input: tl.tensor, if src_sca_ty.is_floating() and dst_sca_ty.is_int(): # TODO: is this correct? if dst_sca_ty.is_bool(): - return tl.tensor(builder.create_fp_to_ui(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return not_equal(input, tl._to_tensor(0, builder), builder) else: return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), @@ -735,8 +732,8 @@ def store(ptr: tl.tensor, elt_ty = ptr_ty.element_ty # treat bool* as tl.int8* if elt_ty == tl.int1: - elt_ty = tl.int8 - ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + elt_ty_ptr = tl.int8 + ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space) ptr = cast(ptr, ptr_ty, builder) # cast to target data-type From 58c8889235e343066d48570e6b59c5383bbe7e6e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 13 Jun 2022 16:21:10 -0700 Subject: [PATCH 131/215] [FRONTEND] Fix scanline layout (#548) --- lib/codegen/selection/generator.cc | 15 ++++++--------- python/test/unit/language/test_core.py | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a04949dc5..f88ecf833 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -3113,24 +3113,21 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { } void generator::visit_layout_scanline(analysis::scanline_layout* layout) { - Value *warp_size = i32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = urem(u_thread_id_0, warp_size); - Value *u_warp_id = udiv(u_thread_id_0, warp_size); - + Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0); auto order = layout->get_order(); const auto& shape = layout->get_shape(); - Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id); // Delinearize size_t dim = shape.size(); std::vector thread_id(dim); for(unsigned k = 0; k 
< dim - 1; k++){ Constant *dim_k = i32(layout->mts(order[k])); - Value *rem = urem(full_thread_id, dim_k); - full_thread_id = udiv(full_thread_id, dim_k); + Value *rem = urem(u_thread_id, dim_k); + u_thread_id = udiv(u_thread_id, dim_k); thread_id[order[k]] = rem; } - thread_id[order[dim - 1]] = full_thread_id; + Constant *dim_k = i32(layout->mts(order[dim - 1])); + thread_id[order[dim - 1]] = urem(u_thread_id, dim_k); + // Create axes for(unsigned k = 0; k < dim; k++) { int nts = layout->nts(k); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c76cbbd95..6ea3ebc9d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -691,7 +691,7 @@ def test_f16_to_f8_rounding(): @pytest.mark.parametrize("dtype_str, shape", [(dtype, shape) for dtype in dtypes - for shape in [128, 512]]) + for shape in [32, 64, 128, 512]]) def test_reduce1d(dtype_str, shape, device='cuda'): # triton kernel From 93209c07e073133617248743c36787c4afe2e1ee Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 13 Jun 2022 16:43:57 -0700 Subject: [PATCH 132/215] [BACKEND][CODEGEN] Fix reduce uint (#547) --- include/triton/ir/instructions.h | 2 +- lib/codegen/selection/generator.cc | 12 ++++-- python/src/triton.cc | 2 + python/test/unit/language/test_core.py | 60 +++++++++++++++++--------- python/triton/language/core.py | 3 ++ python/triton/language/semantic.py | 7 +++ 6 files changed, 61 insertions(+), 25 deletions(-) diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 9f4e18da8..ee7897e03 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -913,7 +913,7 @@ public: class reduce_inst: public builtin_inst { public: enum op_t{ - ADD, SUB, MAX, MIN, + ADD, SUB, MAX, MIN, UMAX, UMIN, FADD, FSUB, FMAX, FMIN, XOR }; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index f88ecf833..53cfb70fc 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -119,6 +119,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) #define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) #define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) +#define icmp_uge(...) builder_->CreateICmpUGE(__VA_ARGS__) +#define icmp_ule(...) builder_->CreateICmpULE(__VA_ARGS__) #define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) #define intrinsic(...) 
builder_->CreateIntrinsic(__VA_ARGS__) @@ -2498,6 +2500,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::SUB: return sub(x, y); case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y); case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y); + case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y); + case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y); case ir::reduce_inst::FADD: return fadd(x, y); case ir::reduce_inst::FSUB: return fsub(x, y); case ir::reduce_inst::FMAX: return max_num(x, y); @@ -2510,9 +2514,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { Value *neutral; switch(op) { case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break; - case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break; - case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; - case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; + case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break; + case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; + case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; + case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break; + case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break; case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; diff --git a/python/src/triton.cc b/python/src/triton.cc index 7f9e7e752..7ebd6b9b9 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -571,6 +571,8 @@ void init_triton_ir(py::module &&m) { .value("FADD", ir::reduce_inst::FADD) .value("MIN", ir::reduce_inst::MIN) .value("MAX", ir::reduce_inst::MAX) + .value("UMIN", ir::reduce_inst::UMIN) + .value("UMAX", ir::reduce_inst::UMAX) .value("FMIN", ir::reduce_inst::FMIN) .value("FMAX", ir::reduce_inst::FMAX) .value("XOR", ir::reduce_inst::XOR); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6ea3ebc9d..348672822 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -688,60 +688,78 @@ def test_f16_to_f8_rounding(): # --------------- -@pytest.mark.parametrize("dtype_str, shape", - [(dtype, shape) +@pytest.mark.parametrize("op, dtype_str, shape", + [(op, dtype, shape) + for op in ['min', 'max', 'sum'] for dtype in dtypes for shape in [32, 64, 128, 512]]) -def test_reduce1d(dtype_str, shape, device='cuda'): +def test_reduce1d(op, dtype_str, shape, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - tl.store(Z, tl.sum(x, axis=0)) + tl.store(Z, GENERATE_TEST_HERE) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'}) + # input rs = RandomState(17) + # limit the range of integers so that the sum does not overflow x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - x[:] = 1 - # numpy result - z_ref = np.sum(x).astype(getattr(np, dtype_str)) - # triton result x_tri = to_triton(x, device=device) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op] + # numpy result + z_ref = numpy_op(x).astype(getattr(np, dtype_str)) + # triton result z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device) kernel[(1,)](x_tri, z_tri, BLOCK=shape) # compare - np.testing.assert_allclose(z_ref, 
to_numpy(z_tri), rtol=0.01) + if op == 'sum': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + else: + np.testing.assert_equal(z_ref, to_numpy(z_tri)) reduce_configs1 = [ - (dtype, (1, 1024), axis) for dtype in ['float32', 'uint32'] + (op, dtype, (1, 1024), axis) for dtype in dtypes + for op in ['min', 'max', 'sum'] for axis in [1] ] reduce_configs2 = [ - ('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] + (op, 'float32', shape, 1) + for op in ['min', 'max', 'sum'] + for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] ] -@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2) -def test_reduce2d(dtype_str, shape, axis, device='cuda'): +@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2) +def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) - z = tl.sum(x, axis=AXIS) + z = GENERATE_TEST_HERE tl.store(Z + range_m, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'}) # input - x = numpy_random(shape, dtype_str=dtype_str) - # triton result + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) x_tri = to_triton(x) - z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) - # numpy reference result - z_ref = np.sum(x, axis=axis).astype(x.dtype) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op] + # numpy result + z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str)) + # triton result + z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) + binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + if op == 'sum': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + else: + np.testing.assert_equal(z_ref, to_numpy(z_tri)) # --------------- # test permute diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f0cc02e66..fa6f190e3 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -136,6 +136,9 @@ class dtype: def is_int_signed(self): return self.name in dtype.SINT_TYPES + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + def is_int(self): return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 753944285..e57faa5ec 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -959,6 +959,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: input = cast(input, tl.int32, builder) + # choose the right unsigned operation + if scalar_ty.is_int_unsigned(): + if INT_OP is ir.REDUCE_OP.MIN: + INT_OP = ir.REDUCE_OP.UMIN + elif INT_OP is ir.REDUCE_OP.MAX: + INT_OP = ir.REDUCE_OP.UMAX + # get result type shape = input.type.shape ret_shape = [] From 8ce2c12e33e3bc4a641a1309644fca67c7acf9f3 Mon Sep 17 00:00:00 2001 From: Madeleine 
Thompson <3121668+madeleineth@users.noreply.github.com> Date: Mon, 13 Jun 2022 19:37:52 -0700 Subject: [PATCH 133/215] [PYTHON] move ephemeral files to homedir (#549) This prevents potential conflicts with other users on shared machines. --- python/setup.py | 3 ++- python/triton/code_gen.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/setup.py b/python/setup.py index 6a04a4e42..af1fa3068 100644 --- a/python/setup.py +++ b/python/setup.py @@ -27,10 +27,11 @@ def get_llvm(): return '', '' # download if nothing is installed name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04' - dir = '/tmp' + dir = os.path.join(os.environ["HOME"], ".triton", "llvm") llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name) llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name) if not os.path.exists(llvm_library_dir): + os.makedirs(dir, exist_ok=True) try: shutil.rmtree(os.path.join(dir, name)) except Exception: diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 60feb1740..90a031a30 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1158,8 +1158,7 @@ class DependenciesFinder(ast.NodeVisitor): def default_cache_dir(): - import getpass - return f'/tmp/triton_{getpass.getuser()}' + return os.path.join(os.environ["HOME"], ".triton", "cache") class JITFunction: From 6b9756532fda4c629fb21a0311c567454adc2e23 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 15 Jun 2022 13:13:20 -0700 Subject: [PATCH 134/215] [BACKEND] Remove print in coalesce.cc (#551) --- lib/codegen/transform/coalesce.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index d969139f1..8092ac527 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -138,7 +138,6 @@ void coalesce::run(ir::module &mod) { } if(in_contig.size() <= 1 || out_contig==in_contig) continue; - std::cout << "3!!" 
<< std::endl; builder.set_insert_point_after(val_inst); auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); x->replace_uses_of_with(val_inst, new_val); From b5e728cb14eb0ed8610fbc012162e6c9c8654f70 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 15 Jun 2022 13:55:20 -0700 Subject: [PATCH 135/215] Add argmin argmax (#552) --- include/triton/codegen/analysis/layout.h | 22 +- include/triton/codegen/selection/generator.h | 14 +- include/triton/ir/instructions.h | 13 +- lib/codegen/analysis/layout.cc | 66 ++++-- lib/codegen/selection/generator.cc | 199 +++++++++++++------ lib/codegen/transform/membar.cc | 9 +- python/src/functions.h | 26 ++- python/src/triton.cc | 6 + python/test/unit/language/test_core.py | 57 ++++-- python/triton/language/core.py | 14 ++ python/triton/language/semantic.py | 20 +- 11 files changed, 345 insertions(+), 101 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 050ac6956..99481f694 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -308,13 +308,20 @@ private: void create(size_t id, const std::vector& values); -public: + void create_tmp_layout(size_t id, data_layout* arg, + const std::vector& axes, + const std::vector& shape, + ir::instruction* i, + bool is_index = false); + + public: // constructor layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt); // accessors unsigned layout_of(ir::value *value) const { return groups_.at(value); } bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); } + bool has(size_t id) { return layouts_.find(id) != layouts_.end(); } const std::vector& values_of(unsigned id) const { return values_.at(id); } size_t num_layouts() const { return values_.size();} data_layout* get(size_t id) { return layouts_.at(id); } @@ -322,7 +329,19 @@ public: std::map &get_all() { return layouts_; } bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); } int tmp(ir::value* i) { return tmp_.at(i);} + int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); } + int tmp_index(ir::value* i) { return tmp_index_.at(i);} void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; } + + // layout checkers + bool is_scanline(ir::instruction* i); + + bool is_coalesced_scanline(ir::instruction* i); + + bool is_mma(ir::instruction* i); + + bool is_a100_mma(ir::instruction* i); + // execution void run(ir::module &mod); @@ -336,6 +355,7 @@ private: std::map> values_; std::map layouts_; std::map tmp_; + std::map tmp_index_; }; } diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 945b9b074..b408a46ca 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -118,8 +118,15 @@ private: llvm::Attribute cvt(ir::attribute attr); void packed_type(ir::value* i); void forward_declare(ir::function* fn); + Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty); -public: + private: + typedef std::function &acc, std::function load_value_fn, + std::function load_index_fn, bool is_first)> + acc_fn_t; + + public: generator(analysis::axes *a_axes, analysis::layouts *layouts, analysis::align *alignment, @@ -176,9 +183,8 @@ public: void visit_trans_inst(ir::trans_inst*); void visit_sqrt_inst(ir::sqrt_inst*); Value* shfl_sync(Value* acc, int32_t i); - void visit_reduce1d_inst(ir::reduce_inst*, std::function, Value*); - void 
visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral); - void visit_reducend_inst(ir::reduce_inst*, std::function, Value*); + void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral); + void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral); void visit_reduce_inst(ir::reduce_inst*); void visit_select_inst(ir::select_inst*); void visit_layout_convert(ir::value *out, ir::value *in); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index ee7897e03..734ea2b42 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -914,7 +914,9 @@ class reduce_inst: public builtin_inst { public: enum op_t{ ADD, SUB, MAX, MIN, UMAX, UMIN, + ARGMAX, ARGMIN, ARGUMAX, ARGUMIN, FADD, FSUB, FMAX, FMIN, + ARGFMAX, ARGFMIN, XOR }; @@ -932,12 +934,19 @@ public: static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr); unsigned get_axis() const { return axis_; } op_t get_op() const { return op_; } + bool with_index() const { + return with_index_ops_.find(op_) != with_index_ops_.end(); + } private: - unsigned axis_; - op_t op_; + const static inline std::set with_index_ops_ = { + op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX, + op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN}; + unsigned axis_; + op_t op_; }; + class select_inst: public builtin_inst { private: select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 86473dc54..a19be19ef 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -588,6 +588,45 @@ void layouts::create(size_t id, const std::vector& values) { } } +// layout checkers +bool layouts::is_scanline(ir::instruction *i) { + return this->get(i->get_operand(0))->to_scanline() != nullptr; +} + +bool layouts::is_coalesced_scanline(ir::instruction *i) { + if (auto *red = dynamic_cast(i)) { + auto *scanline = this->get(i->get_operand(0))->to_scanline(); + return scanline && scanline->get_order()[0] == red->get_axis(); + } + return false; +} + +bool layouts::is_mma(ir::instruction *i) { + return this->get(i->get_operand(0))->to_mma() != nullptr; +} + +bool layouts::is_a100_mma(ir::instruction *i) { + if (auto *red = dynamic_cast(i)) { + return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) && + (red->get_axis() == 1); + } + return false; +} + +void layouts::create_tmp_layout(size_t id, data_layout *arg, + const std::vector &axes, + const std::vector &shape, + ir::instruction *i, bool is_index) { + ir::type *ty = is_index ? 
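      // index buffers (argmin/argmax-style reductions) hold int32 indices; value buffers keep the reduced scalar type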
ir::type::get_int32_ty(i->get_type()->get_context()) + : i->get_type()->get_scalar_ty(); + layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_); + if (is_index) { + tmp_index_[i] = id; + } else { + tmp_[i] = id; + } +} + void layouts::run(ir::module &mod) { // make graph graph_.clear(); @@ -612,22 +651,26 @@ void layouts::run(ir::module &mod) { // std::cout << "layout: " << std::endl; // i->print(std::cout); if(auto *red = dynamic_cast(i)) { - id++; ir::value *arg = red->get_operand(0); - unsigned axis = red->get_axis(); + distributed_layout *layout = + dynamic_cast(get(arg)); // shape auto shapes = arg->get_type()->get_block_shapes(); - distributed_layout* layout = dynamic_cast(get(arg)); - shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis); - + unsigned axis = red->get_axis(); + shapes[axis] = + layout->shape_per_cta(axis) / layout->contig_per_thread(axis); // create layout - layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); - tmp_[red] = id; + id++; + create_tmp_layout(id, layout, axes_->get(arg), shapes, red); + + if (red->with_index()) { + id++; + create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true); + } } if(auto *val = dynamic_cast(i)){ distributed_layout* out_layout = dynamic_cast(get(val)); distributed_layout* in_layout = dynamic_cast(get(i->get_operand(0))); - id++; size_t dim = val->get_type()->get_tile_rank(); ir::type::block_shapes_t shape(dim); for(size_t k = 0; k < dim; k++){ @@ -640,13 +683,12 @@ void layouts::run(ir::module &mod) { int out_vec = out_layout->contig_per_thread(out_ord[0]); int pad = std::max(in_vec, out_vec); shape[out_ord[0]] += pad; - layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_); - tmp_[val] = id; + id++; + create_tmp_layout(id, out_layout, axes_->get(val), shape, val); } if(auto *atom = dynamic_cast(i)){ id++; - layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_); - tmp_[atom] = id; + create_tmp_layout(id, nullptr, {}, {1}, atom); } }); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 53cfb70fc..ebd21732b 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -112,6 +112,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__) #define fadd(...) builder_->CreateFAdd(__VA_ARGS__) #define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) +#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__) +#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__) #define fmul(...) builder_->CreateFMul(__VA_ARGS__) #define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) #define fsub(...) 
builder_->CreateFSub(__VA_ARGS__) @@ -2334,15 +2336,15 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ /** * \brief Code Generation for `reduce` (ND case) */ -void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ - // +void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){ ir::value *arg = x->get_operand(0); + const auto with_index = x->with_index(); + unsigned axis = x->get_axis(); analysis::distributed_layout* layout = dynamic_cast(layouts_->get(arg)); - std::vector shapes = layout->get_shape(); + const auto &shapes = layout->get_shape(); Type* sca_ty = cvt(arg->get_type()->get_scalar_ty()); size_t n_bits = sca_ty->getPrimitiveSizeInBits(); - std::string n_bits_str = std::to_string(n_bits); std::string cst = (n_bits == 64) ? "l" : "r"; @@ -2351,6 +2353,15 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::functionget_local_id(mod_, *builder_, 0); Value* warp = udiv(thread, i32(32)); @@ -2362,54 +2373,64 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function arg_idxs = idxs_.at(arg); size_t n_elts = arg_idxs.size(); unsigned col_per_thread = 0; - Value* warp_i; - Value* warp_j; - if(analysis::scanline_layout* scanline = layout->to_scanline()){ + Value* warp_j = nullptr; + if (analysis::scanline_layout *scanline = layout->to_scanline()) { std::vector order = layout->get_order(); unsigned mts = scanline->mts(order[0]); shuffle_width = std::min(mts, 32); - warps_per_inner = std::max(mts/32, 1); + warps_per_inner = std::max(mts / 32, 1); col_per_thread = shapes[order[0]] / mts; - warp_i = udiv(warp, i32(warps_per_inner)); warp_j = urem(warp, i32(warps_per_inner)); - } - else if(layout->to_mma()){ - shuffle_width = 4; + } else if (layout->to_mma()) { + shuffle_width = 4; warps_per_inner = layout->to_mma()->wpt(1); col_per_thread = 16; - warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id; warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; - } + } + assert(warp_j != nullptr); // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); // - Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); - unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); - Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); + Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)), + cvt(x->get_type()->get_scalar_ty())); + Value *index_base = + with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)), + IntegerType::get(*ctx_, 32)) + : nullptr; + // preds Value* is_lane0 = icmp_eq(lane, i32(0)); Value* is_warp0 = icmp_eq(warp, i32(0)); Value* is_thread0 = icmp_eq(thread, i32(0)); Value* lane_j = urem(lane, i32(shuffle_width)); - Value* first_lane_in_col = icmp_eq(lane_j, i32(0)); add_barrier(); // compute partial sum for each warp, and store to shared memory for(size_t i = 0; i < n_elts/col_per_thread; i++){ - Value* acc; + std::pair acc; // reduce within thread for(size_t j = 0; j < col_per_thread; j++){ - Value* val = arg_vals[arg_idxs[i*col_per_thread + j]]; - // acc = (j == 0) ? val : do_acc(acc, val); - acc = (j == 0) ? 
val : do_acc(acc, val); + auto arg_idx = arg_idxs[i*col_per_thread + j]; + bool is_first = j == 0; + do_acc( + acc, [&]() -> Value * { return arg_vals[arg_idx]; }, + [&]() -> Value * { return arg_idx[axis]; }, is_first); } + // reduce within warp - for(int k = shuffle_width/2 ; k > 0; k >>= 1) - acc = do_acc(acc, shfl_sync(acc, k)); + for(int k = shuffle_width/2 ; k > 0; k >>= 1) { + do_acc( + acc, [&]() -> Value * { return shfl_sync(acc.first, k); }, + [&]() -> Value * { return shfl_sync(acc.second, k); }, false); + } // store partial result to shared memory auto x_idxs = idxs_[x][i]; Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); - call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc}); + call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first}); + if (with_index) { + call(st_shared_index, + {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second}); + } } add_barrier(); // at this point, partial accumulator synchronized in shared memory @@ -2418,48 +2439,66 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::functiongetInt32(0) : x_idxs[0]; Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner))); - Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); - for(int k = warps_per_inner/2; k > 0; k >>= 1) - acc = do_acc(acc, shfl_sync(acc, k)); - vals_[x][idxs_[x][i]] = acc; + std::pair acc; + acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)}); + acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true), + gep(index_base, ld_off)}) + : nullptr; + for (int k = warps_per_inner / 2; k > 0; k >>= 1) { + do_acc( + acc, [&]() -> Value * { return shfl_sync(acc.first, k); }, + [&]() -> Value * { return shfl_sync(acc.second, k); }, false); + } + vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first; } // add_barrier(); } -void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { + +void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) { ir::value *arg = x->get_operand(0); - Type *ty = cvt(x->get_type()->get_scalar_ty()); unsigned axis = x->get_axis(); + auto with_index = x->with_index(); // reduce within thread - std::map accs; + // index-> + std::map> accs; for(indices_t idx: idxs_.at(arg)){ indices_t pidx = idx; pidx[axis] = i32(0); - Value *current = vals_[arg][idx]; bool is_first = accs.find(pidx) == accs.end(); - accs[pidx] = is_first ? current : do_acc(accs[pidx], current); + do_acc( + accs[pidx], [&]() -> Value * { return vals_[arg][idx]; }, + [&]() -> Value * { return idx[axis]; }, is_first); }; // reduce within blocks - analysis::data_layout* layout = layouts_->get(layouts_->tmp(x)); - Value *base = shared_ptr_.at(layout); - auto shape = layout->get_shape(); - auto order = layout->get_order(); - int space = base->getType()->getPointerAddressSpace(); - Value *ptr = bit_cast(base, ptr_ty(ty, space)); + auto *data_layout = layouts_->get(layouts_->tmp(x)); + auto *data_ptr = + cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty())); + auto *index_ptr = + with_index ? 
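      // reductions that return an index also materialize a second shared buffer holding the running int32 index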
cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)), + IntegerType::get(*ctx_, 32)) + : data_ptr; + + auto shape = data_layout->get_shape(); + auto order = data_layout->get_order(); Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id; for(auto& x: accs) { // current element being computed - Value *&acc = x.second; + std::pair acc = x.second; indices_t write_idx = x.first; write_idx[axis] = lane; // shared memory write pointer Value *write_off = shared_off(shape, order, write_idx); - Value *write_ptr = gep(ptr, write_off); + Value *write_ptr = gep(data_ptr, write_off); + Value *index_write_ptr = gep(index_ptr, write_off); // initialize shared memory add_barrier(); - store(acc, write_ptr); + store(acc.first, write_ptr); + if (with_index) { + store(acc.second, index_write_ptr); + } // build result indices_t idx(write_idx.size(), i32(0)); for(size_t i = shape[axis]/2; i > 0; i >>= 1){ @@ -2468,11 +2507,17 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function Value * { return load(read_ptr); }, + [&]() -> Value * { return load(index_read_ptr); }, false); add_barrier(); - store(acc, write_ptr); + store(acc.first, write_ptr); + if (with_index) { + store(acc.second, index_write_ptr); + } } } add_barrier(); @@ -2482,7 +2527,8 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::functionget_type()->get_scalar_ty()); // accumulation function ir::reduce_inst::op_t op = x->get_op(); - auto do_acc = [&](Value *x, Value *y) -> Value* { + auto do_acc_op = [&](Value *x, Value *y) -> Value* { switch(op){ case ir::reduce_inst::ADD: return add(x, y); case ir::reduce_inst::SUB: return sub(x, y); - case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y); - case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y); + case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y); + case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y); + case ir::reduce_inst::ARGMAX: return icmp_sge(x, y); + case ir::reduce_inst::ARGMIN: return icmp_sle(x, y); case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y); case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y); + case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y); + case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y); case ir::reduce_inst::FADD: return fadd(x, y); case ir::reduce_inst::FSUB: return fsub(x, y); + case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y); + case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y); case ir::reduce_inst::FMAX: return max_num(x, y); case ir::reduce_inst::FMIN: return min_num(x, y); case ir::reduce_inst::XOR: return xor_(x, y); default: throw std::runtime_error("unreachable"); } }; + + auto do_acc = [&](std::pair &acc, + std::function load_value_fn, + std::function load_index_fn, + bool is_first) -> void { + auto *val = load_value_fn(); + if (x->with_index()) { + auto *index = load_index_fn(); + if (is_first) { + acc.first = val; + acc.second = index; + } else { + Value *ret = do_acc_op(acc.first, val); + acc.first = select(ret, acc.first, val); + acc.second = select(ret, acc.second, index); + } + } else { + acc.first = is_first ? 
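      // value-only reductions: just fold the new value into the running accumulator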
val : do_acc_op(acc.first, val); + } + }; + // neutral element Value *neutral; switch(op) { case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break; case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break; - case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; - case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; + case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break; + case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break; + case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break; + case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break; case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break; case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break; + case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; + case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; + case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break; + case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break; case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break; - case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break; + case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break; default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - int cc = tgt_->as_nvidia()->sm(); - analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); - analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma(); - bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis()); - bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1); - if(is_coalesced_scanline || is_a100_mma) + bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x); + bool is_a100_mma = layouts_->is_a100_mma(x); + if (is_coalesced_scanline || is_a100_mma) visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); @@ -2938,6 +3014,13 @@ void generator::forward_declare(ir::function* fn){ fns_[fn] = ret; } +Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout, + Type *ty) { + unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); + Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space)); + return base; +} + void generator::visit_function(ir::function* fn) { idxs_.clear(); vals_.clear(); diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 96249bcd5..22fe00fe6 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -60,15 +60,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b continue; analysis::shared_layout* a_layout = layouts_->get(a)->to_shared(); analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr; + analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? 
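    // a reduction's temporary value buffer (and, below, its index buffer) must also be checked when placing barriers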
layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr; for(ir::value* b: bs){ if(!b->get_type()->is_block_ty()) continue; analysis::shared_layout* b_layout = layouts_->get(b)->to_shared(); analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr; + analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr; if(intersect_with(a_layout, b_layout) || intersect_with(a_layout, b_tmp) || + intersect_with(a_layout, b_tmp_index) || intersect_with(a_tmp, b_layout) || - intersect_with(a_tmp, b_tmp)) + intersect_with(a_tmp, b_tmp) || + intersect_with(a_tmp, b_tmp_index) || + intersect_with(a_tmp_index, b_layout) || + intersect_with(a_tmp_index, b_tmp) || + intersect_with(a_tmp_index, b_tmp_index)) ret.insert(b); } } diff --git a/python/src/functions.h b/python/src/functions.h index 19f7e7eb9..d5b6c15ef 100644 --- a/python/src/functions.h +++ b/python/src/functions.h @@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) { return builder->create_sqrt(input); }; -/*---------------------------------------------- - definition of triton.min - ----------------------------------------------*/ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { ir::type *scalar_ty = input->get_type()->get_scalar_ty(); @@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder throw_not_int_or_float(name); } +/*---------------------------------------------- + definition of triton.min + ----------------------------------------------*/ std::string min_docstr = R"pbdoc( Returns the minimum value of `input`. )pbdoc"; @@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) { return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); }; +/*---------------------------------------------- + definition of triton.arg_min + ----------------------------------------------*/ +std::string min_docstr = R"pbdoc( + Returns the minimum value's index of `input`. + )pbdoc"; +ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN); +}; + /*---------------------------------------------- definition of triton.max ----------------------------------------------*/ @@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) { return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); }; +/*---------------------------------------------- + definition of triton.arg_max + ----------------------------------------------*/ +std::string max_docstr = R"pbdoc( + Returns the maximum value's index of `input`. 
+ )pbdoc"; +ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) { + return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX); +}; + /*---------------------------------------------- definition of triton.sum ----------------------------------------------*/ diff --git a/python/src/triton.cc b/python/src/triton.cc index 7ebd6b9b9..4e1849733 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) { .value("MAX", ir::reduce_inst::MAX) .value("UMIN", ir::reduce_inst::UMIN) .value("UMAX", ir::reduce_inst::UMAX) + .value("ARGMIN", ir::reduce_inst::ARGMIN) + .value("ARGMAX", ir::reduce_inst::ARGMAX) + .value("ARGUMIN", ir::reduce_inst::ARGUMIN) + .value("ARGUMAX", ir::reduce_inst::ARGUMAX) .value("FMIN", ir::reduce_inst::FMIN) .value("FMAX", ir::reduce_inst::FMAX) + .value("ARGFMIN", ir::reduce_inst::ARGFMIN) + .value("ARGFMAX", ir::reduce_inst::ARGFMAX) .value("XOR", ir::reduce_inst::XOR); py::enum_(m, "ATOMIC_OP") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 348672822..f1b4f899f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -690,7 +690,7 @@ def test_f16_to_f8_rounding(): @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) - for op in ['min', 'max', 'sum'] + for op in ['min', 'max', 'argmin', 'argmax', 'sum'] for dtype in dtypes for shape in [32, 64, 128, 512]]) def test_reduce1d(op, dtype_str, shape, device='cuda'): @@ -707,28 +707,37 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'): # limit the range of integers so that the sum does not overflow x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) x_tri = to_triton(x, device=device) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op] + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, + 'argmin': np.argmin, 'argmax': np.argmax}[op] # numpy result - z_ref = numpy_op(x).astype(getattr(np, dtype_str)) + z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) # triton result - z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device) + z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device) kernel[(1,)](x_tri, z_tri, BLOCK=shape) + z_tri = to_numpy(z_tri) # compare if op == 'sum': - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) else: - np.testing.assert_equal(z_ref, to_numpy(z_tri)) + if op == 'argmin' or op == 'argmax': + # argmin and argmax can have multiple valid indices. 
+ # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) reduce_configs1 = [ (op, dtype, (1, 1024), axis) for dtype in dtypes - for op in ['min', 'max', 'sum'] + for op in ['min', 'max', 'argmin', 'argmax', 'sum'] for axis in [1] ] reduce_configs2 = [ - (op, 'float32', shape, 1) - for op in ['min', 'max', 'sum'] + (op, 'float32', shape, axis) + for op in ['min', 'max', 'argmin', 'argmax', 'sum'] for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] + for axis in [0, 1] ] @@ -741,7 +750,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): range_n = tl.arange(0, BLOCK_N) x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) z = GENERATE_TEST_HERE - tl.store(Z + range_m, z) + if AXIS == 1: + tl.store(Z + range_m, z) + else: + tl.store(Z + range_n, z) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'}) # input @@ -749,17 +761,30 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): # limit the range of integers so that the sum does not overflow x = numpy_random(shape, dtype_str=dtype_str, rs=rs) x_tri = to_triton(x) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op] + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, + 'argmin': np.argmin, 'argmax': np.argmax}[op] + z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str # numpy result - z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str)) + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result - z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) - binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) + z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs), + device=device) + kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) + z_tri = to_numpy(z_tri) # compare if op == 'sum': - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) else: - np.testing.assert_equal(z_ref, to_numpy(z_tri)) + if op == 'argmin' or op == 'argmax': + # argmin and argmax can have multiple valid indices. 
+ # so instead we compare the values pointed by indices + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) # --------------- # test permute diff --git a/python/triton/language/core.py b/python/triton/language/core.py index fa6f190e3..d775abf40 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1000,6 +1000,13 @@ def max(input, axis, _builder=None): return semantic.max(input, axis, _builder) +@builtin +@_add_reduction_docstr("maximum index") +def argmax(input, axis, _builder=None): + axis = _constexpr_to_value(axis) + return semantic.argmax(input, axis, _builder) + + @builtin @_add_reduction_docstr("minimum") def min(input, axis, _builder=None): @@ -1007,6 +1014,13 @@ def min(input, axis, _builder=None): return semantic.min(input, axis, _builder) +@builtin +@_add_reduction_docstr("minimum index") +def argmin(input, axis, _builder=None): + axis = _constexpr_to_value(axis) + return semantic.argmin(input, axis, _builder) + + @builtin @_add_reduction_docstr("sum") def sum(input, axis, _builder=None): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index e57faa5ec..15a5cb648 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -961,10 +961,14 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, # choose the right unsigned operation if scalar_ty.is_int_unsigned(): - if INT_OP is ir.REDUCE_OP.MIN: - INT_OP = ir.REDUCE_OP.UMIN - elif INT_OP is ir.REDUCE_OP.MAX: - INT_OP = ir.REDUCE_OP.UMAX + int_op_to_unit = { + ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN, + ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX, + ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN, + ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX, + } + if INT_OP in int_op_to_unit: + INT_OP = int_op_to_unit[INT_OP] # get result type shape = input.type.shape @@ -988,10 +992,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) +def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN) + + def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) +def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX) + + def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) From a428cf0bb2b72b2cdaced264d18ca2bb36939627 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 20 Jun 2022 20:12:09 -0700 Subject: [PATCH 136/215] [FRONTEND] Fix pytorch warning. (#560) UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'). 
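For reference, a minimal sketch of why the substitution is safe here (not part of the patch; the tensor values below are made up): rounded_linspace only divides non-negative integers, and for non-negative operands truncating division agrees with the old __floordiv__ behavior, so torch.div(..., rounding_mode='trunc') gives the same result without the warning.

    import torch

    ret = torch.tensor([12, 50, 100], dtype=torch.int32)  # hypothetical linspace output
    div = 32
    # old form: rounds up to the next multiple of `div`, but __floordiv__ warns
    old = (ret + div - 1) // div * div
    # new form: same result for non-negative inputs, no deprecation warning
    new = torch.div(ret + div - 1, div, rounding_mode='trunc') * div
    assert torch.equal(old, new)  # both are tensor([32, 64, 128], dtype=torch.int32)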
--- python/bench/bench_matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index b776b3dbf..30864f391 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -5,7 +5,7 @@ import triton def rounded_linspace(low, high, steps, div): ret = torch.linspace(low, high, steps) - ret = (ret.int() + div - 1) // div * div + ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div ret = torch.unique(ret) return list(map(int, ret)) From b02bac41baceb3959c2da29aed97e6d2563f8da3 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 22 Jun 2022 11:44:35 -0700 Subject: [PATCH 137/215] [CI] Change cache dir (#561) --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 45798e628..f78091568 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -21,7 +21,7 @@ jobs: - name: Clear cache run: | - rm -r /tmp/triton/ + rm -r ~/.triton/ continue-on-error: true - name: Install Triton From d345ddf83782b9abd04399ec3ae63fad6a486dd7 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 22 Jun 2022 17:51:17 -0700 Subject: [PATCH 138/215] [DOCS] Separate atomic cas from other atomic operations since operands are very different (#559) --- docs/python-api/triton.language.rst | 4 ++++ python/triton/language/core.py | 35 +++++++++++++++++++---------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 1f05ce8a6..18bf95be4 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -106,9 +106,13 @@ Atomic Ops :nosignatures: atomic_cas + atomic_xchg atomic_add atomic_max atomic_min + atomic_and + atomic_or + atomic_xor Comparison ops diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d775abf40..1c54ef2c7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -806,6 +806,25 @@ def store(pointer, value, mask=None, _builder=None): # Atomic Memory Operations # ----------------------- +@builtin +def atomic_cas(pointer, cmp, val, _builder=None): + """ + Performs an atomic compare-and-swap at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to compare-and-swap. + :type pointer: Block of dtype=triton.PointerDType + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=`pointer.dtype.element_ty` + :param val: The values to copy in case the expected value matches the contained value. + :type val: Block of dtype=`pointer.dtype.element_ty` + """ + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_cas(pointer, cmp, val, _builder) + + def _add_atomic_docstr(name): def _decorator(func): @@ -814,12 +833,12 @@ def _add_atomic_docstr(name): Return the data stored at :code:`pointer` before the atomic operation. - :param pointer: The memory locations to compare-and-swap. + :param pointer: The memory locations to apply {name}. 
:type pointer: Block of dtype=triton.PointerDType - :param cmp: The values expected to be found in the atomic object - :type cmp: Block of dtype=`pointer.dtype.element_ty` - :param val: The values to copy in case the expected value matches the contained value. + :param val: The values to {name} in the atomic object. :type val: Block of dtype=`pointer.dtype.element_ty` + :param mask: If mask[idx] is false, do not apply {name}. + :type mask: Block of triton.int1, optional """ func.__doc__ = docstr.format(name=name) return func @@ -827,14 +846,6 @@ def _add_atomic_docstr(name): return _decorator -@builtin -@_add_atomic_docstr("compare-and-swap") -def atomic_cas(pointer, cmp, val, _builder=None): - cmp = _to_tensor(cmp, _builder) - val = _to_tensor(val, _builder) - return semantic.atomic_cas(pointer, cmp, val, _builder) - - @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, _builder=None): From 87413bc92522f14da4860adb506a8bc96c5e3a89 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Sat, 25 Jun 2022 23:12:03 -0700 Subject: [PATCH 139/215] [BACKEND] Fix layout convert for non-contiguous input (#564) --- lib/codegen/selection/generator.cc | 15 ++++++++++----- python/test/unit/language/test_core.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index ebd21732b..8d95a2790 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2638,8 +2638,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ // Orders analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); - auto in_ord = in_layout->get_order(); - auto out_ord = out_layout->get_order(); Value *base; base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out))))); base = bit_cast(base, ptr_ty(ty, 3)); @@ -2656,9 +2654,16 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ in_ax.push_back(axes_.at(a_axes_->get(in, d)).values); out_ax.push_back(axes_.at(a_axes_->get(out, d)).values); } - in_ord = in_layout->to_mma() ? out_ord : in_ord; - out_ord = out_layout->to_mma() ? in_ord : out_ord; - int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]); + auto in_ord = + in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order(); + auto out_ord = + out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order(); + // out_ord[0] == 0 or in_order[0] == 0 means the first dimension is + // non-contiguous. in_vec can be greater than 0 only if both out_ord[0] and + // and in_ord[0] are contiguous. + int in_vec = out_ord[0] == 0 ? 1 + : in_ord[0] == 0 ? 1 + : in_layout->contig_per_thread(in_ord[0]); int out_vec = out_ord[0] == 0 ? 
1 : out_layout->contig_per_thread(out_ord[0]); int pad = std::max(in_vec, out_vec); Value *in_ld = i32(shape[in_ord[0]] + pad); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f1b4f899f..a5fb0acba 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -793,8 +793,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): @pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) - for dtype in ['float32'] - for shape in [(128, 128)] + for dtype in ['float16', 'float32'] + for shape in [(64, 64), (128, 128)] for perm in [(1, 0)]]) def test_permute(dtype_str, shape, perm, device='cuda'): @@ -812,18 +812,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'): x = numpy_random(shape, dtype_str=dtype_str) # triton result z_tri = to_triton(np.empty_like(x), device=device) + z_tri_contiguous = to_triton(np.empty_like(x), device=device) x_tri = to_triton(x, device=device) pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), BLOCK_M=shape[0], BLOCK_N=shape[1]) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), + z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), + BLOCK_M=shape[0], BLOCK_N=shape[1]) # torch result z_ref = x.transpose(*perm) # compare triton.testing.assert_almost_equal(z_tri, z_ref) + triton.testing.assert_almost_equal(z_tri_contiguous, z_ref) # parse ptx to make sure ld/st are vectorized ptx = pgm.asm['ptx'] assert 'ld.global.v4' in ptx assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx # --------------- # test dot From 5b4c8f221e00290ae2250b4985c3aa7190c96b7b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 27 Jun 2022 11:49:19 -0700 Subject: [PATCH 140/215] [BACKEND] Compiler improvements (#557) This PR adds several optimization capabilities in the compiler backend: - Now using inline PTX for `tl.store`, making it possible to use things like evict_last - For A100, mma layout can be directly converted to shared memory - For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major. - Fixed liveness analysis; this was broken. - Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop. - `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention. 
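For the last two bullets, a minimal sketch of the pattern this enables (not taken from this patch; the kernel name, pointer arguments and single-tile row-major layout are made up for illustration): the second tl.dot consumes the first one's accumulator directly, so the backend can keep it in registers instead of spilling it to shared memory between the two dots.

    import triton
    import triton.language as tl

    @triton.jit
    def chained_dot_kernel(Q, Kmat, V, Out,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        # single-tile, row-major fp16 loads
        q = tl.load(Q + offs_m[:, None] * BLOCK_K + offs_k[None, :])     # (BLOCK_M, BLOCK_K)
        k = tl.load(Kmat + offs_k[:, None] * BLOCK_N + offs_n[None, :])  # (BLOCK_K, BLOCK_N)
        v = tl.load(V + offs_n[:, None] * BLOCK_K + offs_k[None, :])     # (BLOCK_N, BLOCK_K)
        s = tl.dot(q, k)       # first dot: fp32 accumulator held in registers (mma layout)
        p = s.to(tl.float16)   # downcast without a round trip through DRAM or shared memory
        o = tl.dot(p, v)       # second dot takes its LHS straight from the first dot's output
        tl.store(Out + offs_m[:, None] * BLOCK_K + offs_k[None, :], o)

A launch such as chained_dot_kernel[(1,)](q, k, v, out, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64) on 64x64 fp16 inputs with an fp32 output is enough to exercise the path; this is the core of the fused-attention use case mentioned above.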
--- include/triton/codegen/analysis/layout.h | 5 +- include/triton/codegen/transform/coalesce.h | 3 +- include/triton/codegen/transform/cts.h | 14 +- include/triton/ir/builder.h | 6 +- include/triton/ir/function.h | 2 +- include/triton/ir/instructions.h | 48 ++- include/triton/ir/utils.h | 1 + lib/codegen/analysis/allocation.cc | 4 +- lib/codegen/analysis/layout.cc | 59 ++- lib/codegen/analysis/liveness.cc | 121 ++++-- lib/codegen/analysis/swizzle.cc | 11 +- lib/codegen/pass.cc | 14 +- lib/codegen/selection/generator.cc | 413 ++++++++++++++------ lib/codegen/transform/coalesce.cc | 16 +- lib/codegen/transform/cts.cc | 71 ++-- lib/codegen/transform/peephole.cc | 2 +- lib/ir/basic_block.cc | 5 +- lib/ir/builder.cc | 12 +- lib/ir/instructions.cc | 30 +- lib/ir/utils.cc | 9 + python/test/unit/language/test_core.py | 57 ++- python/triton/code_gen.py | 2 +- python/triton/language/core.py | 8 +- python/triton/language/semantic.py | 55 +-- python/tutorials/06-fused-attention.py | 198 ++++++++++ 25 files changed, 882 insertions(+), 284 deletions(-) create mode 100644 python/tutorials/06-fused-attention.py diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 99481f694..a69687875 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -258,7 +258,8 @@ public: const std::vector& shapes, const std::vector &values_, ir::type *ty, - analysis::align* align, target *tgt); + analysis::align* align, target *tgt, + bool is_tmp = false); void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } // accessors size_t get_size() { return size_; } @@ -276,6 +277,7 @@ public: int get_mma_strided() { return mma_strided_; } bool allow_swizzle() const { return allow_swizzle_; } data_layout* get_arg_layout() { return arg_layout_; } + bool is_tmp() const { return is_tmp_; } private: size_t size_; @@ -290,6 +292,7 @@ private: int mma_strided_; bool allow_swizzle_ = true; target *tgt_; + bool is_tmp_; }; diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h index 869ca9975..e16ffe5fe 100644 --- a/include/triton/codegen/transform/coalesce.h +++ b/include/triton/codegen/transform/coalesce.h @@ -32,11 +32,12 @@ private: ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map& seen); public: - coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); + coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80); triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder); void run(ir::module &mod); private: + bool has_sm80_; analysis::align* align_; analysis::layouts* layout_; }; diff --git a/include/triton/codegen/transform/cts.h b/include/triton/codegen/transform/cts.h index 70fbc474b..30b421b52 100644 --- a/include/triton/codegen/transform/cts.h +++ b/include/triton/codegen/transform/cts.h @@ -15,18 +15,26 @@ namespace ir { } namespace codegen{ + +namespace analysis{ +class layouts; +} + namespace transform{ class cts { private: - void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared); + bool is_shmem_op(ir::instruction* i, int op); + bool is_shmem_res(ir::value* i); +void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map& copies); public: - cts(bool use_async = false): use_async_(use_async) {} + cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {} void 
run(ir::module &mod); private: - bool use_async_; + bool has_sm80_; + analysis::layouts* layouts_; }; } diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 0cb622679..74028f822 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -142,9 +142,9 @@ public: value *create_or(value *lhs, value *rhs); // Input/Output value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); - value *create_store(value *ptr, value *val); + value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction); value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); - value *create_masked_store(value *ptr, value *val, value *mask); + value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction); // Struct instructions value *create_insert_value(value* val, value *elt, size_t idx); value *create_extract_value(value* val, size_t idx); @@ -176,7 +176,7 @@ public: value *create_cos(value* arg); value *create_sin(value* arg); value *create_log(value* arg); - value *create_dot(value *A, value *B, value *C, bool allow_tf32); + value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32); value *create_trans(value *A, const std::vector &perm = {}); value *create_sqrt(value *A); value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 4e76e60a4..61ec2a6ae 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -112,7 +112,7 @@ public: static function *create(function_type *ty, linkage_types_t linkage, const std::string &name, module *mod); // blocks - const blocks_t &blocks() { return blocks_; } + blocks_t &blocks() { return blocks_; } const blocks_t &blocks() const { return blocks_; } void insert_block(basic_block* block, basic_block *next = nullptr); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 734ea2b42..402208a8b 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -435,13 +435,31 @@ private: //===----------------------------------------------------------------------===// class io_inst: public instruction { +public: + + enum EVICTION_POLICY : uint32_t { + NORMAL=0, + EVICT_FIRST, + EVICT_LAST, + }; + protected: - io_inst(type *ty, value_id_t id, unsigned num_ops, + io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); + std::string get_eviction_policy_repr() const { + if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; + if (eviction_ == EVICT_LAST) return ".L2::evict_last"; + return ""; + } + public: // accessors value *get_pointer_operand() { return get_operand(0); } + EVICTION_POLICY get_eviction_policy() const { return eviction_; } + +protected: + EVICTION_POLICY eviction_; }; // load @@ -453,14 +471,8 @@ public: CG, }; - enum EVICTION_POLICY : uint32_t { - NORMAL=0, - EVICT_FIRST, - EVICT_LAST, - }; CACHE_MODIFIER get_cache_modifier() const { return cache_; } - EVICTION_POLICY get_eviction_policy() const { return eviction_; } bool get_is_volatile() const { return is_volatile_; } protected: @@ -472,12 +484,6 @@ protected: if (cache_ == CG) return ".cg"; return ""; } - std::string get_eviction_policy_repr() const { - if (eviction_ == EVICT_FIRST) 
return ".L1::evict_first"; - if (eviction_ == EVICT_LAST) return ".L2::evict_last"; - return ""; - } - EVICTION_POLICY eviction_; CACHE_MODIFIER cache_; std::string get_volatile_repr() { @@ -553,7 +559,7 @@ public: // store class store_inst: public io_inst { protected: - store_inst(value *ptr, value_id_t id, unsigned num_ops, + store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); public: @@ -564,11 +570,11 @@ public: class unmasked_store_inst: public store_inst{ private: std::string repr_impl() const { return "unmasked_store"; } - unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next); + unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next); public: // factory method - static unmasked_store_inst* create(value* ptr, value *v, + static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(unmasked_store_inst) @@ -578,14 +584,14 @@ public: class masked_store_inst: public store_inst{ private: std::string repr_impl() const { return "masked_store"; } - masked_store_inst(value *ptr, value *v, value *mask, + masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction, const std::string &name, instruction *next); public: // accessors value *get_mask_operand() { return get_operand(2); } // factory method - static masked_store_inst* create(value *ptr, value *v, value *mask, + static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_store_inst) @@ -755,6 +761,8 @@ private: class atomic_inst: public io_inst { public: using io_inst::io_inst; + atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next): + io_inst(ty, id, num_ops, NORMAL, name, next) {} }; class atomic_rmw_inst: public atomic_inst { @@ -856,6 +864,8 @@ public: bool is_prefetched() const { return is_prefetched_; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } bool allow_tf32() const { return allow_tf32_; } + bool is_trans_a() const { return AT_ == Trans; } + bool is_trans_b() const { return BT_ == Trans; } public: static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); @@ -872,6 +882,8 @@ private: DataType C_type_ = DataType::FP32; DataType A_type_ = DataType::FP16; DataType B_type_ = DataType::FP16; + TransT AT_; + TransT BT_; }; //class outer_inst: public builtin_inst { diff --git a/include/triton/ir/utils.h b/include/triton/ir/utils.h index 893edd122..1fad79181 100644 --- a/include/triton/ir/utils.h +++ b/include/triton/ir/utils.h @@ -22,6 +22,7 @@ public: }; void for_each_instruction(ir::module& mod, const std::function &fn); +void for_each_instruction_backward(module &mod, const std::function &do_work); void for_each_value(ir::module& mod, const std::function &fn); } diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 3af40c2cc..f842c0f61 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) { } // Save maximum size of induced memory space allocated_size_ = 0; - for(shared_layout* x: V) + for(shared_layout* x: V){ 
allocated_size_ = std::max(allocated_size_, starts[x] + x->get_size()); + // std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl; + } } } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index a19be19ef..69c36f752 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -212,11 +212,9 @@ mma_layout::mma_layout(size_t num_warps, order_ = {0, 1}; } else{ - // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 contig_per_thread_ = {1, 2}; order_ = {1, 0}; - // rep_ = {2, 2, 1}; } /* warps per tile */ @@ -233,24 +231,45 @@ mma_layout::mma_layout(size_t num_warps, }while(wpt_nm1 != wpt_); } else { bool changed = false; - do { - changed = false; - if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) - break; - if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { - if (wpt_[0] < shape_[0] / spw_[0]) { - wpt_[0] *= 2; - changed = true; + // try to have a warp own entire rows of the output + // this makes it easier to fuse multiple mmas by fusing + // registers + bool one_warp_per_row = false; + for(ir::value* v: values) + for(ir::user* u: v->get_users()){ + auto* dot = dynamic_cast(u); + auto* cts = dynamic_cast(u); + if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts) + one_warp_per_row = shape[0] / spw_[0] >= num_warps; + } + // std::cout << one_warp_per_row << std::endl; + + if(one_warp_per_row){ + wpt_[1] = 1; + wpt_[0] = num_warps; + } + else{ + do { + changed = false; + if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) + break; + if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { + if (wpt_[0] < shape_[0] / spw_[0]) { + wpt_[0] *= 2; + changed = true; + } + } else { + if (wpt_[1] < shape_[1] / (spw_[1]*2)) { + wpt_[1] *= 2; + changed = true; + } } - } else { - if (wpt_[1] < shape_[1] / (spw_[1]*2)) { - wpt_[1] *= 2; - changed = true; - } - } - } while (changed); + } while(changed); + } } + // std::cout << wpt_[0] << " " << wpt_[1] << std::endl; + /* shape per block */ shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; } @@ -430,8 +449,8 @@ shared_layout::shared_layout(data_layout *arg, const std::vector& shape, const std::vector &values, ir::type *ty, - analysis::align* align, target *tgt) - : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) { + analysis::align* align, target *tgt, bool is_tmp) + : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){ size_ = 0; arg_layout_ = arg; @@ -619,7 +638,7 @@ void layouts::create_tmp_layout(size_t id, data_layout *arg, ir::instruction *i, bool is_index) { ir::type *ty = is_index ? 
ir::type::get_int32_ty(i->get_type()->get_context()) : i->get_type()->get_scalar_ty(); - layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_); + layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true); if (is_index) { tmp_index_[i] = id; } else { diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 7beae21a1..535df4eb9 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -14,43 +14,108 @@ namespace analysis{ void liveness::run(ir::module &mod) { intervals_.clear(); - // Assigns index to each instruction - std::map indices; - for(ir::function *fn: mod.get_function_list()){ - slot_index index = 0; - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *instr: block->get_inst_list()){ - index += 1; - indices.insert({instr, index}); + std::map> layouts_map; + for(auto &x: layouts_->get_all()){ + shared_layout* layout = x.second->to_shared(); + if(!layout || layout->is_tmp()) + continue; + for(ir::value* v:layout->get_values()){ + layouts_map[v].insert(layout); } } - // create live intervals + + + std::map> live_in; + while(true){ + bool changed = false; + ir::instruction* last_inst = nullptr; + ir::for_each_instruction_backward(mod, [&](ir::instruction* i){ + // gen + std::set gen; + for(ir::value* v: i->ops()) + for(shared_layout* layout: layouts_map[v]) + gen.insert(layout); + // kill + std::set kill; + for(shared_layout* layout: layouts_map[i]) + kill.insert(layout); + // temporaries are handled separately + if(layouts_->has_tmp(i)){ + gen.insert(layouts_->get(layouts_->tmp(i))->to_shared()); + kill.insert(layouts_->get(layouts_->tmp(i))->to_shared()); + } + if(layouts_->has_tmp_index(i)){ + gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared()); + kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared()); + } + // live-out + std::set live_out; + std::vector succs = {last_inst}; + if(i == i->get_parent()->get_inst_list().back()) + for(ir::basic_block* succ: i->get_parent()->get_successors()) + succs.push_back(succ->get_inst_list().front()); + for(ir::instruction* succ: succs) + for(shared_layout* layout: live_in[succ]) + if(!layout->is_tmp()) + live_out.insert(layout); + + // new sets + std::set live_out_minus_kill; + std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(), + std::inserter(live_out_minus_kill, live_out_minus_kill.end())); + std::set new_live_in; + std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(), + std::inserter(new_live_in, new_live_in.end())); + + changed = changed || (new_live_in != live_in[i]); + live_in[i] = new_live_in; + last_inst = i; + }); + if(!changed) + break; + } + + // ir::for_each_instruction(mod, [&](ir::instruction* i){ + // i->print(std::cout); + // std::cout << " live_in: " << live_in[i].size() << std::endl; + // }); + + + + // Assigns index to each instruction + std::map indices; + slot_index index = 0; + ir::for_each_instruction(mod, [&](ir::instruction* instr){ + index += 1; + indices.insert({instr, index}); + }); + + + for(auto &x: layouts_->get_all()){ + shared_layout* layout = x.second->to_shared(); + if(layout) + intervals_[layout] = segment{INT32_MAX, 0}; + } + + for(auto& x: live_in) + for(shared_layout* layout: x.second) + intervals_[layout].start = std::min(intervals_[layout].start, indices[x.first]); + + for(auto& x: live_in) + for(shared_layout* layout: x.second){ + intervals_[layout].end = std::max(intervals_[layout].end, indices[x.first] 
+ 1); + } + + for(auto &x: layouts_->get_all()) { shared_layout* layout = x.second->to_shared(); if(!layout) continue; - // users - std::set users; - for(ir::value *v: layout->get_values()){ - for(ir::user *u: v->get_users()) - users.insert(u); - } - // compute intervals - unsigned start = INT32_MAX; - for(ir::value *v: layout->get_values()) - if(indices.find(v) != indices.end()) - start = std::min(start, indices.at(v)); - unsigned end = 0; - for(ir::user *u: users) - if(indices.find(u) != indices.end()) - end = std::max(end, indices.at(u)); - if(end == 0) - end = start + 1; - intervals_[layout] = segment{start, end}; + // std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl; } - + } diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index 5737f80a0..08843bbf7 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -28,12 +28,15 @@ void swizzle::run(ir::module &) { } auto ord = layout->get_order(); scanline_layout* in_layout = dynamic_cast(layout->get_arg_layout()); - if(!in_layout) - continue; + int per_phase = 1; int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + if(in_layout) + per_phase = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + else + per_phase = 1; if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ int inner = mma_dot_a ? 0 : 1; - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + per_phase_[layout] = per_phase; max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout]; if(mma_dot_a) vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); @@ -46,7 +49,7 @@ void swizzle::run(ir::module &) { max_phase_[layout] = 1; vec_[layout] = 1; } else { - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + per_phase_[layout] = per_phase; max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; vec_[layout] = layout->get_mma_vec(); } diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 4ba423d20..412e2f4c8 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -31,27 +31,28 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); // optimizations - bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80; + bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; codegen::transform::inliner inliner; codegen::analysis::axes axes; - codegen::transform::cts cts(cts_use_async); - codegen::transform::pipeline pipeline(cts_use_async, num_stages); + codegen::transform::pipeline pipeline(has_sm80, num_stages); codegen::transform::disassociate disassociate; codegen::analysis::layouts layouts(&axes, &align, num_warps, target); + codegen::transform::cts cts(&layouts, has_sm80); codegen::analysis::liveness liveness(&layouts); codegen::analysis::swizzle swizzle(&layouts, target); codegen::analysis::allocation allocation(&liveness); codegen::transform::dce dce; codegen::transform::peephole peephole(target, &layouts); - codegen::transform::coalesce coalesce(&align, &layouts); + codegen::transform::coalesce coalesce(&align, &layouts, has_sm80); codegen::transform::prefetch prefetch_s(target); codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); codegen::generator isel(&axes, 
&layouts, &align, &allocation, &swizzle, target, num_warps); // run passes inliner.run(ir); dce.run(ir); + // ir.print(std::cout); peephole.run(ir); dce.run(ir); pipeline.run(ir); @@ -84,10 +85,15 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC axes.run(ir); layouts.run(ir); swizzle.run(ir); + // std::cout << "---" << std::endl; + // ir.print(std::cout); + // std::cout << "---" << std::endl; + // ir.print(std::cout); liveness.run(ir); allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); + // exit(1); // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 8d95a2790..e69b0acee 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -744,11 +744,13 @@ void generator::visit_load_inst(ir::load_inst* x){ BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); + Value *lane = urem(tid, i32(32)); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); // compute vector width size_t vec = 1; + bool is_mma_first_row = false; if(op->get_type()->is_block_ty()){ auto ord = ords_.at(op); size_t aln = alignment_->get(op, ord[0]); @@ -757,11 +759,15 @@ void generator::visit_load_inst(ir::load_inst* x){ max_eq = std::max(max_eq, 1); aln = std::min(aln, max_eq); } - auto layout = layouts_->get(x)->to_scanline(); - if(layout){ - size_t nts = layout->nts(ord[0]); - vec = std::min(nts, aln); - } + analysis::distributed_layout* layout = dynamic_cast(layouts_->get(x)); + assert(layout); + + vec = std::min(layout->contig_per_thread(ord[0]), aln); + // TODO: generalize + is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && + (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1)); + if(is_mma_first_row) + vec = std::min(2, aln); } // code generation auto idxs = idxs_.at(x); @@ -795,8 +801,8 @@ void generator::visit_load_inst(ir::load_inst* x){ int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); - bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; - has_evict_policy = false; // currently disable until supported in `store` + bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; + // has_evict_policy = false; // currently disable until supported in `store` // ----- // create inline asm string // ----- @@ -810,7 +816,7 @@ void generator::visit_load_inst(ir::load_inst* x){ if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; - if (has_evict_policy) asm_oss << ".L2::cache_hint"; + if (has_l2_evict_policy) asm_oss << ".L2::cache_hint"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size @@ -822,7 +828,7 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_oss << "}"; asm_oss << ", [ $" << n_words + 1; // load asm_oss << " + " << in_off << "]"; // constant offset - if (has_evict_policy) asm_oss << ", $" << n_words + 2; + if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; asm_oss << ";"; bool has_other 
= other && (other != UndefValue::get(other->getType())); std::vector others; @@ -844,7 +850,7 @@ void generator::visit_load_inst(ir::load_inst* x){ if(ConstantInt* cst = dyn_cast(v)) asm_oss << "0x" << std::hex << cst->getSExtValue(); else{ - asm_oss << "$" << n_words + has_evict_policy + 2 + ii; + asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii; others.push_back(v); } asm_oss.flags(flags); @@ -859,7 +865,7 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); - if (has_evict_policy) + if (has_l2_evict_policy) arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); // --- @@ -875,7 +881,7 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); } - if (has_evict_policy) + if (has_l2_evict_policy) asm_cstrt += ",l"; // --- // finally call inline ASM @@ -884,7 +890,7 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); - if (has_evict_policy) + if (has_l2_evict_policy) args.push_back(policies_.at(x->get_eviction_policy())); @@ -935,6 +941,9 @@ void generator::visit_store_inst(ir::store_inst * x){ // operands ir::value *ptr_op = x->get_pointer_operand(); ir::value *val_op = x->get_value_operand(); + ir::value *msk_op = nullptr; + if(auto* msk_st = dynamic_cast(x)) + msk_op = msk_st->get_mask_operand(); // vector size size_t vec = 1; if(val_op->get_type()->is_block_ty()){ @@ -946,36 +955,107 @@ void generator::visit_store_inst(ir::store_inst * x){ max_eq = std::max(max_eq, 1); aln = std::min(aln, max_eq); } - vec = std::min(nts, aln); + analysis::distributed_layout* layout = dynamic_cast(layouts_->get(ptr_op)); + assert(layout); + // vec = std::min(nts, aln); + vec = std::min(layout->contig_per_thread(ord[0]), aln); + // TODO: generalize + bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && + (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1)); + if(is_mma_first_row) + vec = std::min(2, aln); } + bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store ty = f16_ty; + if(ty->isIntegerTy(1)) + ty = builder_->getInt8Ty(); for(size_t i = 0; i < idxs.size(); i += vec){ - auto idx = idxs[i]; - // pointer + indices_t idx = idxs[i]; + // pointers Value *ptr = vals_[ptr_op][idx]; - // vectorize - Type *v_ty = vec_ty(ty, vec); - ptr = bit_cast(ptr, v_ty->getPointerTo(1)); - // value - Value* val = UndefValue::get(v_ty); - for(size_t ii = 0; ii < vec; ii++) - val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii); - if(mx){ - Value *msk = vals_[mx->get_mask_operand()][idx]; - Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); - builder_->SetInsertPoint(no_op->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - store(val, ptr); - builder_->SetInsertPoint(no_op); + size_t dtsize = std::max(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8); + GetElementPtrInst *in_gep = dyn_cast(ptr); + size_t in_off; + if(in_gep){ + ConstantInt* cst = dyn_cast(in_gep->idx_begin()); + 
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; + ptr = cst ? in_gep->getPointerOperand() : in_gep; } - else - store(val, ptr); + else{ + in_off = 0; + } + // mask + Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue(); + size_t nbits = dtsize*8; + // pack sub-words (< 32/64bits) into words + // each load has width min(nbits*vec, 32/64) + // and there are (nbits * vec)/width of them + int max_word_width = std::max(32, nbits); + int tot_width = nbits*vec; + int width = std::min(tot_width, max_word_width); + int n_words = std::max(1, tot_width / width); + // ----- + // create inline asm string + // ----- + std::ostringstream asm_oss; + asm_oss << "@$0"; // predicate + asm_oss << " st.global"; + if (has_l2_evict_policy) asm_oss << ".L2::cache_hint"; + if(n_words > 1) + asm_oss << ".v" << n_words; // vector width + asm_oss << ".b" << width; // word size + asm_oss << " [ $1 + " << in_off << "]"; + asm_oss << " , {"; + for(int i = 0; i < n_words; i++){ // return values + if(i > 0) asm_oss << ","; + asm_oss << "$" << 2 + i; + } + asm_oss << "}"; + if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; + asm_oss << ";"; + // ---- + // create inline ASM signature + // --- + Type* val_arg_ty = IntegerType::get(*ctx_, width); + std::vector arg_tys = {pred->getType(), ptr->getType()}; + for(int ii = 0; ii < n_words; ii++) + arg_tys.push_back(val_arg_ty); + if (has_l2_evict_policy) + arg_tys.push_back(i64_ty); + FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false); + // --- + // create inline ASM constraints + // --- + std::string asm_cstrt = "b,l"; + for(int ii = 0; ii < n_words; ii++){ + asm_cstrt += ","; + asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + } + if (has_l2_evict_policy) + asm_cstrt += ",l"; + // --- + // finally call inline ASM + // --- + InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); + std::vector args = {pred, ptr}; + for(unsigned int ii = 0; ii < n_words; ii++){ + size_t n_subw = width / nbits; + Value* curr = UndefValue::get(vec_ty(ty, n_subw)); + for(unsigned int jj = 0; jj < n_subw; jj++){ + Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]]; + if(new_elt->getType()->isIntegerTy(1)) + new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty()); + new_elt = bit_cast(new_elt, ty); + curr = builder_->CreateInsertElement(curr, new_elt, jj); + } + args.push_back(bit_cast(curr, val_arg_ty)); + } + if (has_l2_evict_policy) + args.push_back(policies_.at(x->get_eviction_policy())); + call(_asm, args); } } void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { @@ -1098,6 +1178,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){ InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false); for(auto idx: idxs_.at(x)){ Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); + // Value *ex2arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = call(ex2, std::vector{ex2arg}); } } @@ -1291,6 +1372,18 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va // order auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); + bool is_a_trans = C->is_trans_a(); + // is_a_trans = false; + if(C->is_trans_a()){ + std::swap(ord_a[0], ord_a[1]); + std::swap(shape_a[0], shape_a[1]); + std::swap(offset_a_m_, offset_a_k_); + } + // std::cout << "visiting" << std::endl; + // if(C->is_trans_b()){ + // std::swap(ord_b[0], ord_b[1]); + // std::swap(shape_b[0], shape_b[1]); + // } // layouts analysis::mma_layout* 
layout_c = layouts_->get(C)->to_mma(); analysis::shared_layout* layout_a = layouts_->get(A)->to_shared(); @@ -1322,6 +1415,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va int step_b0 = is_b_row ? stride_rep_n : stride_rep_k; int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1); + + // max_phase_a = 4; + // vec_a = 8; + // std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl; + // std::cout << vec_a << " " << vec_b << std::endl; + /* --------------------------------- */ /* --- pre-compute pointer lanes --- */ /* --------------------------------- */ @@ -1916,12 +2015,17 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: auto shape_a = A->get_type()->get_block_shapes(); auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); + if(C->is_trans_a()){ + std::swap(ord_a[0], ord_a[1]); + std::swap(shape_a[0], shape_a[1]); + } auto ord_b = layouts_->get(B)->get_order(); + if(C->is_trans_b()){ + std::swap(ord_b[0], ord_b[1]); + std::swap(shape_b[0], shape_b[1]); + } + NK = shape_a[1]; analysis::mma_layout* layout = layouts_->get(C)->to_mma(); - analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); - analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); - bool is_a_row = ord_a[0] == 1; - bool is_b_row = ord_b[0] == 1; std::vector mma_instr_shape = layout->get_mma_instr_shape(); const int mma_instr_m = mma_instr_shape[0]; @@ -1933,10 +2037,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: const int mat_shape_n = mat_shape[1]; const int mat_shape_k = mat_shape[2]; - const int per_phase_a = swizzle_->get_per_phase(layout_a); - const int max_phase_a = swizzle_->get_max_phase(layout_a); - const int per_phase_b = swizzle_->get_per_phase(layout_b); - const int max_phase_b = swizzle_->get_max_phase(layout_b); const int num_rep_m = shapes[0] / layout->shape_per_cta(0); const int num_rep_n = shapes[1] / layout->shape_per_cta(1); @@ -2001,7 +2101,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); - if(FirstBB != CurrBB) + + // if true, this will move pointer declarations to the entry basic block + // not prefetched cases tend to be more limited in resource usage + // so we don't pre-compute ptrs to save registers + bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB); + if(licm_ptrs) builder_->SetInsertPoint(FirstBB->getTerminator()); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); @@ -2015,47 +2120,137 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + ir::phi_node* phiA = dynamic_cast(A); + ir::phi_node* phiB = dynamic_cast(B); + auto register_lds2 = + [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { + if (k < 2 && is_prefetch) { + ir::basic_block* inc_block = phiA->get_incoming_block(inc); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); + } else + vals[{mn, k}] = val; + }; + // | -> k (row-major), since we have 
ldmatrix.trans, we only need to change stride // v (s0_0(0), s1_0(2), | *num_rep_k // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) // ----------- // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0)) - mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, - {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, - per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); - std::vector off_a = a_loader.compute_offs(warp_m, lane); - int num_ptr_a = a_loader.get_num_ptr(); + std::function load_a; + analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared(); + bool is_a_shared = layout_a != nullptr; + if(is_a_shared) { + const int per_phase_a = swizzle_->get_per_phase(layout_a); + const int max_phase_a = swizzle_->get_max_phase(layout_a); + mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, + {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, + per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); + std::vector off_a = a_loader.compute_offs(warp_m, lane); + int num_ptr_a = a_loader.get_num_ptr(); + // pointers + std::vector ptrs_a(num_ptr_a); + if(licm_ptrs) + builder_->SetInsertPoint(CurrBB); + for(int i = 0; i < num_ptr_a; i++) + ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); + if(licm_ptrs) + builder_->SetInsertPoint(FirstBB->getTerminator()); + // loading function + load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable { + auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], + shared_next_ptr_[layout_a], off_a, ptrs_a, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(ha, m, k, inc, ha0, is_prefetch); + register_lds2(ha, m+1, k, inc, ha1, is_prefetch); + register_lds2(ha, m, k+1, inc, ha2, is_prefetch); + register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); + }; + } + else { + load_a = [&](int m, int k, int inc, bool is_prefetch) { + distributed_axis ax_n = axes_.at(a_axes_->get(A, 1)); + int ldm = ax_n.values.size(); + if(ldm != num_rep_k*4) + throw std::runtime_error("Internal compiler error when trying to fuse matmuls!"); + // std::cout << m << " " << k << std::endl; + // std::cout << idxs_[A].size() << std::endl; + // std::cout << (m+1)*ldm + k*2 + 3 << std::endl; + // int ldm = num_rep_k*4; + Value* ha0 = UndefValue::get(fp16x2_ty); + Value* ha1 = UndefValue::get(fp16x2_ty); + Value* ha2 = UndefValue::get(fp16x2_ty); + Value* ha3 = UndefValue::get(fp16x2_ty); + ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); + ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); + ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0)); + ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1)); + ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0)); + ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1)); + ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0)); + ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1)); + ha[{m, k}] = ha0; + ha[{m+1, k}] = ha1; + ha[{m, k+1}] = ha2; + ha[{m+1, k+1}] = ha3; + }; + } + // | -> n (col-major) // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) // 
----------- // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) - mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b, - {mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n}, + analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared(); + const int per_phase_b = swizzle_->get_per_phase(layout_b); + const int max_phase_b = swizzle_->get_max_phase(layout_b); + std::vector mma_instr_b{mma_instr_k, mma_instr_n}; + std::vector mat_shape_b{mat_shape_k, mat_shape_n}; + int k_order_b = 0; + // if(C->is_trans_b()){ + // std::swap(mma_instr_b[0], mma_instr_b[1]); + // std::swap(mat_shape_b[0], mat_shape_b[1]); + // k_order_b = k_order_b ^ 1; + // std::swap(ord_b[0], ord_b[1]); + // std::swap(shape_b[0], shape_b[1]); + // } + + mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b, + mma_instr_b, mat_shape_b, per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); std::vector off_b = b_loader.compute_offs(warp_n, lane); - int num_ptr_b = b_loader.get_num_ptr(); - builder_->SetInsertPoint(CurrBB); - // A pointer - std::vector ptrs_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++) - ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); - // B pointer + if(licm_ptrs) + builder_->SetInsertPoint(CurrBB); + // pointers + int num_ptr_b = b_loader.get_num_ptr(); std::vector ptrs_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++) ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); - InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + + + // loading function + std::function load_b; + load_b = [&](int n, int k, int inc, bool is_prefetch) { + auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], + shared_next_ptr_[layout_b], off_b, ptrs_b, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(hb, n, k, inc, hb0, is_prefetch); + register_lds2(hb, n+1, k, inc, hb2, is_prefetch); + register_lds2(hb, n, k+1, inc, hb1, is_prefetch); + register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); + }; + + + + // create mma & unpack result, m, n, k are offsets in mat + auto call_mma = [&](unsigned m, unsigned n, unsigned k) { + InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + " {$0, $1, $2, $3}," " {$4, $5, $6, $7}," " {$8, $9}," " {$10, $11, $12, $13};", "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true); - - // create mma & unpack result, m, n, k are offsets in mat - auto call_mma = [&](unsigned m, unsigned n, unsigned k) { unsigned cols_per_thread = num_rep_n * 2; std::vector idx = { (m + 0)*cols_per_thread + (n*2 + 0), @@ -2072,39 +2267,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: fc[idx[2]] = extract_val(nc, std::vector{2}); fc[idx[3]] = extract_val(nc, std::vector{3}); }; - - ir::phi_node* phiA = dynamic_cast(A); - ir::phi_node* phiB = dynamic_cast(B); - - auto register_lds2 = - [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { - if (k < 2 && is_prefetch) { - ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); - } else - vals[{mn, k}] = val; - }; - - auto load_a = [&](int m, int k, int inc, bool is_prefetch) { - auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], - shared_next_ptr_[layout_a], off_a, ptrs_a, - ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); - register_lds2(ha, m, k, inc, ha0, is_prefetch); - 
register_lds2(ha, m+1, k, inc, ha1, is_prefetch); - register_lds2(ha, m, k+1, inc, ha2, is_prefetch); - register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); - }; - - auto load_b = [&](int n, int k, int inc, bool is_prefetch) { - auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], - shared_next_ptr_[layout_b], off_b, ptrs_b, - ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); - register_lds2(hb, n, k, inc, hb0, is_prefetch); - register_lds2(hb, n+1, k, inc, hb2, is_prefetch); - register_lds2(hb, n, k+1, inc, hb1, is_prefetch); - register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); - }; - if (C->is_prefetched()) { // create phis builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); @@ -2163,6 +2325,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: i = 0; vals_[C][idx] = fcs.at(key)[i++]; }; + } /** @@ -2384,7 +2547,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va } else if (layout->to_mma()) { shuffle_width = 4; warps_per_inner = layout->to_mma()->wpt(1); - col_per_thread = 16; + col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size(); warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; } assert(warp_j != nullptr); @@ -2403,7 +2566,8 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va Value* is_warp0 = icmp_eq(warp, i32(0)); Value* is_thread0 = icmp_eq(thread, i32(0)); Value* lane_j = urem(lane, i32(shuffle_width)); - add_barrier(); + if(warps_per_inner > 1) + add_barrier(); // compute partial sum for each warp, and store to shared memory for(size_t i = 0; i < n_elts/col_per_thread; i++){ std::pair acc; @@ -2425,13 +2589,21 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va // store partial result to shared memory auto x_idxs = idxs_[x][i]; Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; - Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); - call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first}); - if (with_index) { - call(st_shared_index, - {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second}); + // single warp on the reduce dimension -- no need to use shmem + if(warps_per_inner==1){ + vals_[x][idxs_[x][i]] = with_index ? 
acc.second : acc.first; + } + else{ + Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); + call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first}); + if (with_index) { + call(st_shared_index, + {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second}); + } } } + if(warps_per_inner==1) + return; add_barrier(); // at this point, partial accumulator synchronized in shared memory // Just need to reduce `warp_per_inner` numbers in shared memory @@ -2559,6 +2731,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::FMAX: return max_num(x, y); case ir::reduce_inst::FMIN: return min_num(x, y); case ir::reduce_inst::XOR: return xor_(x, y); + default: throw std::runtime_error("unreachable"); } }; @@ -2639,7 +2812,9 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); Value *base; - base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out))))); + int off = alloc_->offset(layouts_->get(layouts_->tmp(out))); + // std::cout << off << std::endl; + base = gep(shmem_, i32(off)); base = bit_cast(base, ptr_ty(ty, 3)); std::vector n_reps; for(int i = 0; i < shape.size(); i++){ @@ -2821,15 +2996,26 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]); + if(in_layout->to_mma()){ + mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]); + mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]); + per_phase = 1; + max_phase = 8; + } int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; - int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); int n_shared_0 = std::max(in_vec / out_vec, 1); + int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); + if(in_layout->to_mma()){ + n_shared_0 = 8; + n_shared_1 = 1; + } BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); auto shapes = cts->get_type()->get_block_shapes(); + // store to shared Value *current = nullptr; std::map, Value*> ptrs; @@ -2844,9 +3030,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; - int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; - int off = (off_1*shapes[in_order[0]] + off_0); + // std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl; std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ if(FirstBB->getTerminator()) @@ -2865,6 +3049,13 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { builder_->SetInsertPoint(CurrBB); ptrs[key] = gep(shmems_.at(cts), {off}); } + int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; + int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; + if(in_layout->to_mma()){ + off_0 = id_0/n_shared_0*n_shared_0*8; + off_1 = id_1/n_shared_1*n_shared_1*8; + } + int off = (off_1*shapes[in_order[0]] + off_0); Value* ptr = gep(ptrs[key], {i32(off)}); ptr = bit_cast(ptr, current->getType()->getPointerTo(3)); // asm @@ -3069,7 +3260,7 @@ void generator::visit_function(ir::function* fn) { if(tgt_->as_nvidia()->sm() >= 80) for(ir::load_inst::EVICTION_POLICY evict: 
{ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last"; - std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;"; + std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;"; InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false); policies_[evict] = call(iasm); } @@ -3299,7 +3490,6 @@ void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ - // i->print(std::cout); visit_value(i); // std::cout << "done" << std::endl; } @@ -3324,7 +3514,10 @@ void generator::init_idx(ir::value *v) { std::vector axes(rank); std::vector ord(rank); // compute axes + // std::cout << "axes" << std::endl; for(size_t d = 0; d < shapes.size(); d++){ + // std::cout << d << " " << shapes[d] << std::endl; + // std::cout << a_axes_->get(v, d) << std::endl; if(shapes[d] > 1){ unsigned x = a_axes_->get(v, d); axes[d] = axes_.at(x); @@ -3334,6 +3527,7 @@ void generator::init_idx(ir::value *v) { axes[d].values = {i32(0)}; } } + // std::cout << "axes ok" << std::endl; // compute order analysis::data_layout* layout = layouts_->get(v); std::iota(ord.begin(), ord.end(), 0); @@ -3480,6 +3674,7 @@ void generator::finalize_phi_node(ir::phi_node *x) { for(indices_t idx: idxs_.at(x)){ PHINode *phi = (PHINode*)vals_[x][idx]; Value *inc = vals_[x->get_incoming_value(n)][idx]; + // x->print(std::cout); phi->addIncoming(inc, block); } } diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 8092ac527..8b5ad3625 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -12,8 +12,8 @@ namespace triton { namespace codegen{ namespace transform{ -coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) - : align_(align), layout_(layouts) { } +coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80) + : align_(align), layout_(layouts), has_sm80_(has_sm80) { } // simplify layout conversions using the following simple rules: @@ -64,15 +64,18 @@ void coalesce::run(ir::module &mod) { if(op->get_type()->is_block_ty()) if(op->get_type()->get_tile_rank() == 2) if(invalidated.find(layout_->get(op)) == invalidated.end()) - if(layout_->get(op)->to_mma()){ + if(layout_->get(op)->to_mma()) + if(dynamic_cast(i)->get_eviction_policy()==ir::io_inst::NORMAL){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); builder.set_insert_point(i); builder.insert(new_op); i->replace_uses_of_with(op, new_op); } // coalesce before copy_to_shared - // It's dirty, but the backend is being rewritten from scratch. 
:) - if(dynamic_cast(i)) + // only necessary for sm < 80 as Ampere+ can handle reduction + // on MMA layout + if(!has_sm80_) + if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(0)) if(op->get_type()->is_block_ty()) if(op->get_type()->get_tile_rank() == 2) @@ -89,7 +92,8 @@ void coalesce::run(ir::module &mod) { if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) if(x->get_type()->get_tile_rank()==2) - if(layout_->get(x)->to_mma()){ + if(layout_->get(x)->to_mma()) + if(!has_sm80_ || dynamic_cast(i)->get_eviction_policy()==ir::io_inst::NORMAL){ builder.set_insert_point_after(x); ir::instruction* new_x = ir::cvt_layout_inst::create(x); builder.insert(new_x); diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc index c223d2413..4606b0f57 100644 --- a/lib/codegen/transform/cts.cc +++ b/lib/codegen/transform/cts.cc @@ -1,8 +1,10 @@ +#include "triton/codegen/analysis/layout.h" #include "triton/codegen/transform/cts.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" +#include "triton/ir/utils.h" #include namespace triton { @@ -10,9 +12,9 @@ namespace codegen{ namespace transform{ -inline bool is_shmem_op(ir::instruction* i, int op) { +bool cts::is_shmem_op(ir::instruction* i, int op) { if(i->get_id() == ir::INST_DOT) - return op==0 || op==1; + return op == 0 || op == 1; if(i->get_id() == ir::INST_COPY_FROM_SHARED) return op==0; if(i->get_id() == ir::INST_TRANS) @@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) { return false; } -inline bool is_shmem_res(ir::value* v){ +bool cts::is_shmem_res(ir::value* v){ ir::instruction* i = dynamic_cast(v); if(!i) return false; @@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){ // run pass on module -void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { +void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map& copies) { auto *i = dynamic_cast(x); // not an instruction if(!i) { @@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, // phi node if(auto* phi = dynamic_cast(x)) { for(unsigned i = 0; i < phi->get_num_incoming(); ++i) - add_copy(phi, phi->get_incoming_value(i), builder, to_shared); + add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies); return; } // already in shared memory @@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, } else copy = builder.create_copy_from_shared(x); - parent->replace_uses_of_with(x, copy); + copies.insert({x, copy}); + parent->replace_uses_of_with(x, copies.at(x)); } void cts::run(ir::module &mod) { - // Add shared copies - ir::builder &builder = mod.get_builder(); - for(ir::function* fn: mod.get_function_list()){ - for(ir::basic_block* block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - size_t num_op = i->get_num_operands(); - // copy to shared operands - for(size_t k = 0; k < num_op; k++) - if(is_shmem_op(i, k)){ - add_copy(i, i->get_operand(k), builder, true); - } - // copy from shared operands - for(size_t k = 0; k < num_op; k++) - if(!dynamic_cast(i) && - !is_shmem_op(i,k) && - is_shmem_res(i->get_operand(k))){ - add_copy(i, i->get_operand(k), builder, false); - } + // Precompute where copies should be added + std::set shmem_ops; + std::set shmem_res; + ir::for_each_instruction(mod, [&](ir::instruction* i) { + if(i->get_id() == 
ir::INST_DOT){ + ir::dot_inst* dot = dynamic_cast(i); + ir::value* lhs = i->get_operand(0); + ir::type* ty = lhs->get_type()->get_scalar_ty(); + analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma(); + // TODO: V100 + bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a()); + if(is_lhs_shmem) + shmem_ops.insert(lhs); + shmem_ops.insert(i->get_operand(1)); } - } + if(i->get_id() == ir::INST_COPY_FROM_SHARED) + shmem_ops.insert(i->get_operand(0)); + if(i->get_id() == ir::INST_TRANS) + shmem_ops.insert(i->get_operand(0)); + if(i->get_id() == ir::INST_TRANS || + i->get_id() == ir::INST_COPY_TO_SHARED || + i->get_id() == ir::INST_MASKED_LOAD_ASYNC) + shmem_res.insert(i); + }); + + // Add shared copies + std::map copies; + ir::builder &builder = mod.get_builder(); + ir::for_each_instruction(mod, [&](ir::instruction* i) { + size_t num_op = i->get_num_operands(); + for(size_t k = 0; k < num_op; k++){ + ir::value* op = i->get_operand(k); + // copy to shared operands + bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end(); + if(is_shmem_op) + add_copy(i, op, builder, true, copies); + } + }); } diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index c25a252a8..a7d3f8240 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -87,7 +87,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ ir::value *a = dot->get_operand(0); ir::value *b = dot->get_operand(1); builder.set_insert_point(add); - ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name())); + ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name())); add->replace_all_uses_with(new_dot); return true; } diff --git a/lib/ir/basic_block.cc b/lib/ir/basic_block.cc index 93caef2c3..0bbc3af0f 100644 --- a/lib/ir/basic_block.cc +++ b/lib/ir/basic_block.cc @@ -26,7 +26,10 @@ void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) auto* curr_phi = dynamic_cast(i); if(!curr_phi) break; - curr_phi->replace_uses_of_with(before, after); + // curr_phi->replace_uses_of_with(before, after); + for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx) + if (curr_phi->get_incoming_block(idx) == before) + curr_phi->set_incoming_block(idx, after); } } diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 4060f23bb..510994fd8 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile)); } -value *builder::create_store(value *ptr, value *val){ - return insert(unmasked_store_inst::create(ptr, val)); +value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){ + return insert(unmasked_store_inst::create(ptr, val, eviction)); } value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile)); } -value *builder::create_masked_store(value *ptr, value *val, value *mask){ - return insert(masked_store_inst::create(ptr, val, mask)); +value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){ + return 
insert(masked_store_inst::create(ptr, val, mask, eviction)); } //===----------------------------------------------------------------------===// @@ -412,8 +412,8 @@ value *builder::create_log(value *arg){ return insert(log_inst::create(arg)); } -value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) { - return insert(dot_inst::create_nn(A, B, C, allow_tf32)); +value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) { + return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32)); } value *builder::create_trans(value *A, const std::vector& perm) { diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 325976504..dbee5e0ee 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){ // Add incoming void phi_node::add_incoming(value *v, basic_block *block){ + assert(v && "PHI node got a null value!!"); resize_ops(get_num_operands() + 1); blocks_.resize(get_num_operands() + 1); set_incoming_value(get_num_operands() - 1, v); @@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vectorget_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile) + : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile) { } // load @@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, // store -store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) - : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next) +store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next) + : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next) { } // unmasked_store -unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, +unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction, const std::string &name, instruction *next) - : store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) { + : store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) { set_operand(0, ptr); set_operand(1, val); } -unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, +unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction, const std::string &name, instruction *next) { - return new unmasked_store_inst(ptr, val, name, next); + return new unmasked_store_inst(ptr, val, eviction, name, next); } // masked store -masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, +masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction, const std::string &name, instruction *next) - : store_inst(ptr, INST_MASKED_STORE, 3, name, next) { + : store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) { set_operand(0, ptr); set_operand(1, val); set_operand(2, mask); } -masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { - return new masked_store_inst(ptr, val, mask, name, next); +masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction, + const std::string &name, instruction *next) { + return new masked_store_inst(ptr, val, mask, 
eviction, name, next); } //===----------------------------------------------------------------------===// @@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next) - : builtin_inst(C->get_type(), INST_DOT, 3, name, next) { + : builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){ set_operand(0, A); set_operand(1, B); set_operand(2, C); diff --git a/lib/ir/utils.cc b/lib/ir/utils.cc index cbfb4baf9..9abaef5c0 100644 --- a/lib/ir/utils.cc +++ b/lib/ir/utils.cc @@ -43,6 +43,15 @@ std::vector cfg::reverse_post_order(function* fn) { return result; } +void for_each_instruction_backward(module &mod, const std::function &do_work) { + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: cfg::post_order(fn)){ + auto inst_list = block->get_inst_list(); + for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++) + do_work(*it); + } +} + void for_each_instruction(module &mod, const std::function &do_work) { for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: cfg::reverse_post_order(fn)) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a5fb0acba..6987d0c26 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -840,10 +840,10 @@ def test_permute(dtype_str, shape, perm, device='cuda'): @pytest.mark.parametrize("epilogue, allow_tf32, dtype", [(epilogue, allow_tf32, dtype) - for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] for allow_tf32 in [True, False] - for dtype in ['float32', 'int8'] - if not (allow_tf32 and (dtype == 'int8'))]) + for dtype in ['float16'] + if not (allow_tf32 and (dtype in ['float16']))]) def test_dot(epilogue, allow_tf32, dtype, device='cuda'): cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) if cc < 80: @@ -852,21 +852,30 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'): elif dtype == 'float32' and allow_tf32: pytest.skip("Only test tf32 on devices with sm >= 80") + M, N, K = 128, 128, 64 + num_warps = 8 + trans_a, trans_b = False, False + # triton kernel @triton.jit def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, + W, stride_wn, stride_wl, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - ALLOW_TF32: tl.constexpr): + ALLOW_TF32: tl.constexpr, + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, + TRANS_A: tl.constexpr, TRANS_B: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) off_k = tl.arange(0, BLOCK_K) Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn - z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32) + z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32) if ADD_MATRIX: z += tl.load(Zs) if ADD_ROWS: @@ -875,39 +884,65 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'): if ADD_COLS: ZCs = Z + off_n * stride_zn z += 
tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + # tl.store(Zs, z) + # tl.debug_barrier() + z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A) tl.store(Zs, z) # input - M, N, K = 64, 64, 32 rs = RandomState(17) - x = numpy_random((M, K), dtype_str=dtype, rs=rs) - y = numpy_random((K, N), dtype_str=dtype, rs=rs) + x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1 + y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1 + w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1 if allow_tf32: x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') x_tri = to_triton(x, device=device) y_tri = to_triton(y, device=device) + w_tri = to_triton(w, device=device) # triton result - z = numpy_random((M, N), dtype_str=dtype, rs=rs) + z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1 z_tri = to_triton(z, device=device) if epilogue == 'trans': z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), + w_tri, w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), + TRANS_A=trans_a, TRANS_B=trans_b, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX=epilogue == 'add-matrix', ADD_ROWS=epilogue == 'add-rows', ADD_COLS=epilogue == 'add-cols', - ALLOW_TF32=allow_tf32) + DO_SOFTMAX=epilogue == 'softmax', + CHAIN_DOT=epilogue == 'chain-dot', + ALLOW_TF32=allow_tf32, + num_warps=num_warps) # torch result - z_ref = np.matmul(x, y) + x_ref = x.T if trans_a else x + y_ref = y.T if trans_b else y + z_ref = np.matmul(x_ref, y_ref) if epilogue == 'add-matrix': z_ref += z if epilogue == 'add-rows': z_ref += z[:, 0][:, None] if epilogue == 'add-cols': z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + z_ref = np.matmul(z_ref.T if trans_a else z_ref, w) # compare + # print(z_ref[:,0], z_tri[:,0]) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # make sure ld/st are vectorized ptx = pgm.asm['ptx'] diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 90a031a30..27c9e1bfe 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -211,7 +211,7 @@ class ValueConstructor: return phi v = unique_handles.pop() phi.handle.replace_all_uses_with(v) - phi.handle.erase_from_parent() + # phi.handle.erase_from_parent() # TODO: remove trivial phis recursively return triton.language.tensor(v, phi.type) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 1c54ef2c7..3ba2d2f21 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -732,7 +732,7 @@ def reshape(input, shape, _builder=None): @builtin -def dot(input, other, allow_tf32=True, _builder=None): +def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None): """ Returns the matrix product of two blocks. 
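A minimal sketch of how the extended `tl.dot` signature above can be used: with `trans_a=True`, an operand laid out as a (K, M) block in memory is consumed as its (M, K) transpose without an explicit `tl.trans` or extra copy. The kernel and launch below are illustrative only — the names, shapes and single-block grid are assumptions, not part of the patch — and block dimensions must be at least 16, per the semantic check added in this change.

import torch
import triton
import triton.language as tl

@triton.jit
def _dot_trans_a(X, Y, Z,
                 stride_xk, stride_xm, stride_yk, stride_yn, stride_zm, stride_zn,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, BLOCK_K)
    # X is stored as a (K, M) block; trans_a=True makes tl.dot read it as (M, K)
    Xs = X + off_k[:, None] * stride_xk + off_m[None, :] * stride_xm
    Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
    Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
    z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=True, trans_b=False, allow_tf32=True)
    tl.store(Zs, z)

# single-block launch: x is (K, M), y is (K, N), z holds x.T @ y with shape (M, N)
M, N, K = 32, 32, 32
x = torch.randn((K, M), device='cuda', dtype=torch.float16)
y = torch.randn((K, N), device='cuda', dtype=torch.float16)
z = torch.empty((M, N), device='cuda', dtype=torch.float32)
_dot_trans_a[(1,)](x, y, z,
                   x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1),
                   BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)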
@@ -744,7 +744,7 @@ def dot(input, other, allow_tf32=True, _builder=None): :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ allow_tf32 = _constexpr_to_value(allow_tf32) - return semantic.dot(input, other, allow_tf32, _builder) + return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder) # ----------------------- @@ -782,7 +782,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", @builtin -def store(pointer, value, mask=None, _builder=None): +def store(pointer, value, eviction_policy="", mask=None, _builder=None): """ Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. @@ -799,7 +799,7 @@ def store(pointer, value, mask=None, _builder=None): value = _to_tensor(value, _builder) if mask is not None: mask = _to_tensor(mask, _builder) - return semantic.store(pointer, value, mask, _builder) + return semantic.store(pointer, value, mask, eviction_policy, _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 15a5cb648..2d137b904 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -648,6 +648,18 @@ def cast(input: tl.tensor, # ===----------------------------------------------------------------------===// +def _parse_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], @@ -684,14 +696,7 @@ def load(ptr: tl.tensor, raise ValueError(f"Cache modifier {cache_modifier} not supported") # eviction policy - eviction = ir.EVICTION_POLICY.NORMAL # default - if eviction_policy: - if eviction_policy == "evict_last": - eviction = ir.EVICTION_POLICY.EVICT_LAST - elif eviction_policy == "evict_first": - eviction = ir.EVICTION_POLICY.EVICT_FIRST - else: - raise ValueError(f"Eviction policy {eviction_policy} not supported") + eviction = _parse_eviction_policy(eviction_policy) if ptr.type.is_block(): shape = ptr.type.get_block_shapes() @@ -721,6 +726,7 @@ def load(ptr: tl.tensor, def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], + eviction_policy: str, builder: ir.builder) -> tl.tensor: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) @@ -735,14 +741,15 @@ def store(ptr: tl.tensor, elt_ty_ptr = tl.int8 ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space) ptr = cast(ptr, ptr_ty, builder) - + # eviction policy + eviction = _parse_eviction_policy(eviction_policy) # cast to target data-type val = cast(val, elt_ty, builder) if not mask: - return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void) + return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void) if not mask.type.scalar.is_bool(): raise ValueError("Mask must have boolean scalar type") - return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void) ######### # atomic @@ -897,27 +904,31 @@ def atomic_xchg(ptr: tl.tensor, # 
===----------------------------------------------------------------------===// -def dot(lhs: tl.tensor, - rhs: tl.tensor, +def dot(a: tl.tensor, + b: tl.tensor, + trans_a: bool, + trans_b: bool, allow_tf32: bool, builder: ir.builder) -> tl.tensor: - assert lhs.type.is_block() and rhs.type.is_block() - assert len(lhs.shape) == 2 and len(rhs.shape) == 2 - assert lhs.shape[-1] == rhs.shape[0] - assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\ + in_a = 1 if not trans_a else 0 + in_b = 1 if trans_b else 0 + assert a.type.is_block() and b.type.is_block() + assert len(a.shape) == 2 and len(b.shape) == 2 + assert a.shape[in_a] == b.shape[in_b] + assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\ "small blocks not supported!" - if lhs.type.scalar.is_int(): + if a.type.scalar.is_int(): _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 else: _0 = builder.get_float32(0) ret_scalar_ty = tl.float32 - M = lhs.type.shape[0] - N = rhs.type.shape[1] + M = a.type.shape[in_a ^ 1] + N = b.type.shape[in_b ^ 1] _0 = builder.create_splat(_0, [M, N]) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), - ret_ty) + ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32) + return tl.tensor(ret, ret_ty) # ===----------------------------------------------------------------------===// diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py new file mode 100644 index 000000000..eb9b40c60 --- /dev/null +++ b/python/tutorials/06-fused-attention.py @@ -0,0 +1,198 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, K, V, + TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kk, stride_kn, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_qm = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + q = tl.load(q_ptrs) + for start_n in range(0, start_qm + 1): + # -- compute qk ---- + k = tl.load(k_ptrs) + qk = tl.dot(q, k) + qk += tl.where(offs_m[:, None] >= (start_n * BLOCK_N + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + p = p.to(tl.float16) + # 
scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs) + acc += tl.dot(p, v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + # r_ptrs += BLOCK_N + l_i = l_i_new + m_i = m_i_new + + start_qm = tl.program_id(0) + offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_out = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_out + tl.store(out_ptrs, acc) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v): + BLOCK = 128 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-2] + assert Lq == Lk + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _fwd_kernel[grid]( + q, k, v, + tmp, L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=64, num_warps=4, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + return o + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_MODEL', [(2, 3, 1024, 64)]) +def test_op(Z, H, N_CTX, D_MODEL, dtype=torch.float16): + torch.manual_seed(20) + q = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) + k = .5 * torch.randn((Z, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True) + v = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) + # triton implementation + tri_out = attention(q, k, v) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_qk = torch.matmul(q, k) + for z in range(Z): + for h in range(H): + ref_qk[:, :, M == 0] = float("-inf") + ref_qk = torch.softmax(ref_qk, dim=-1) + ref_out = torch.matmul(ref_qk, v) + # compare + triton.testing.assert_almost_equal(ref_out, tri_out) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 2048, 64 +# vary batch size for fixed heads / seq +batch_bench = triton.testing.Benchmark( + x_names=['BATCH'], + x_vals=[2**i for i in range(0, 8)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-seq{N_CTX}-head{N_HEADS}-d{D_HEAD}', + args={'H': N_HEADS, 'N_CTX': N_CTX, 'D_MODEL': D_HEAD, 'dtype': torch.float16} +) +# vary seq length for fixed head and batch=4 +seq_bench = triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 16)], + 
line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}', + args={'H': D_HEAD, 'BATCH': BATCH, 'D_MODEL': D_HEAD, 'dtype': torch.float16} +) + + +@triton.testing.perf_report([batch_bench, seq_bench]) +def bench_flash_attention(BATCH, H, N_CTX, D_MODEL, provider, dtype=torch.float16, device="cuda"): + warmup = 25 + rep = 500 + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) + fn = lambda: attention(q, k, v) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_MODEL), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) + return ms + + +bench_flash_attention.run(save_path='.', print_data=True) From feb7a2a0dc617fc00221cbf87b99d910adeaf09e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 28 Jun 2022 00:24:02 -0700 Subject: [PATCH 141/215] [FRONTEND] Hotfix for `store` argument order (#567) --- python/triton/language/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3ba2d2f21..ee99aab2e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -782,7 +782,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", @builtin -def store(pointer, value, eviction_policy="", mask=None, _builder=None): +def store(pointer, value, mask=None, eviction_policy="", _builder=None): """ Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. From 1895ceaa2d9748f14c89a65d8ac76d443ea5c8ed Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 29 Jun 2022 18:39:10 +0200 Subject: [PATCH 142/215] [TUTORIAL] Fix f-string for older python (#569) fixes issue #568 --- python/triton/code_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 27c9e1bfe..30a79bcc9 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -956,7 +956,7 @@ class Kernel: return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): - assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"{num_warps=} must be a power of 2." + assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2." 
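The `{num_warps=}` form removed above is a self-documenting f-string expression, a feature that only exists in Python 3.8 and later; on Python 3.7 and earlier it is a SyntaxError when the module is parsed. Plain interpolation produces the same message on every supported version. A quick illustration, assuming num_warps = 4:

num_warps = 4
f"{num_warps=} must be a power of 2."           # Python >= 3.8 only -> 'num_warps=4 must be a power of 2.'
f"num_warps={num_warps} must be a power of 2."  # portable           -> 'num_warps=4 must be a power of 2.'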
# handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} wargs = list(wargs) From 1bbb2430d9437ba0aeb54082971c05221799d45c Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 29 Jun 2022 17:00:22 -0700 Subject: [PATCH 143/215] [TUTORIALS] adjust heuristics for dwdb kernel (#565) --- python/tutorials/05-layer-norm.py | 48 +++++++++++++++++-------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 9880b428f..333cb80ec 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb( cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, M, BLOCK_SIZE_M): - rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] - a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) - dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) - mean = tl.load(Mean + rows, mask=rows < M, other=0.) - rstd = tl.load(Var + rows, mask=rows < M, other=0.) - a_hat = (a - mean[:, None]) * rstd[:, None] - dw += dout * a_hat - db += dout + UNROLL: tl.constexpr = 4 + for i in range(0, M, BLOCK_SIZE_M * UNROLL): + for j in range(UNROLL): + rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) + dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) + mean = tl.load(Mean + rows, mask=rows < M, other=0.) + rstd = tl.load(Var + rows, mask=rows < M, other=0.) 
+ a_hat = (a - mean[:, None]) * rstd[:, None] + dw += dout * a_hat + db += dout sum_dw = tl.sum(dw, axis=0) sum_db = tl.sum(db, axis=0) tl.store(DW + cols, sum_dw, mask=cols < N) @@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function): BLOCK_SIZE_N=ctx.BLOCK_SIZE, num_warps=ctx.num_warps, ) - # accumulate partial sums in separate kernel + if N > 10240: + BLOCK_SIZE_N = 128 + BLOCK_SIZE_M = 32 + num_warps = 4 + else: + # maximize occupancy for small N + BLOCK_SIZE_N = 16 + BLOCK_SIZE_M = 16 + num_warps = 8 grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] _layer_norm_bwd_dwdb[grid]( a, dout, @@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function): dbias, M, N, - BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + num_warps=num_warps ) - return (da, None, dweight, dbias, None, None, - None, None, None, None, - None, - None, None, None, - None, - None, None, None, - None, None, None, - None, None, None) + return (da, None, dweight, dbias, None) def layer_norm(a, normalized_shape, weight, bias, eps): From f733327ba4cda869562eacb3e7bb43b9c4b9a132 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 29 Jun 2022 17:05:13 -0700 Subject: [PATCH 144/215] [BACKEND][CODEGEN] Disabling L2 residency control by default (#570) --- lib/codegen/selection/generator.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e69b0acee..48a15c4af 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -802,6 +802,7 @@ void generator::visit_load_inst(ir::load_inst* x){ int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; + has_l2_evict_policy = false; // has_evict_policy = false; // currently disable until supported in `store` // ----- // create inline asm string @@ -966,6 +967,7 @@ void generator::visit_store_inst(ir::store_inst * x){ vec = std::min(2, aln); } bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; + has_l2_evict_policy = false; auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store From a74cce375f9a0b6e29dbf5b2473828ae45bc5c61 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 30 Jun 2022 17:32:07 -0700 Subject: [PATCH 145/215] [FRONTEND] Raise broadcast error (#555) --- python/test/unit/language/test_core.py | 32 ++++++++++++++++++++------ python/triton/language/semantic.py | 5 ++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6987d0c26..92c854f06 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -385,6 +385,8 @@ def test_index1d(expr, dtype_str, device='cuda'): rank_y = expr.count(',') + 1 shape_x = [32 for _ in range(rank_x)] shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] # Triton kernel @triton.jit @@ -395,12 +397,17 @@ def test_index1d(expr, dtype_str, device='cuda'): z = GENERATE_TEST_HERE tl.store(Z_PTR_EXPR, z) - to_replace = { - 'X_PTR_EXPR': make_ptr_str('X', shape_x), - 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), - 'GENERATE_TEST_HERE': expr, - } - kernel 
= patch_kernel(kernel, to_replace) + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) # torch result x = numpy_random(shape_x, dtype_str=dtype_str) @@ -409,10 +416,21 @@ def test_index1d(expr, dtype_str, device='cuda'): # triton result z_tri = to_triton(np.empty_like(z_ref), device=device) x_tri = to_triton(x) - kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) # compare assert (z_ref == to_numpy(z_tri)).all() + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + except triton.code_gen.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + # --------------- # test tuples diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 2d137b904..9025319d6 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -482,6 +482,11 @@ def broadcast_impl_shape(input: tl.tensor, raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") if shape == src_shape: return input + for i in range(len(src_shape)): + if shape[i] != src_shape[i] and src_shape[i] != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({src_shape[1]}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") ret_ty = tl.block_type(input.type.scalar, shape) return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) From 4bf509889bae68ef20668ff6edadbb293c7e431f Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 1 Jul 2022 12:17:22 -0700 Subject: [PATCH 146/215] [BUILD] Change the default build type to Release (#571) --- CMakeLists.txt | 5 ----- python/setup.py | 23 +++++++++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c3aadf9c7..e2bc7f309 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,11 +3,6 @@ include(ExternalProject) set(CMAKE_CXX_STANDARD 17) -if(NOT TRITON_LLVM_BUILD_DIR) - set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR}) -endif() - - project(triton) include(CTest) if(NOT WIN32) diff --git a/python/setup.py b/python/setup.py index af1fa3068..7ed6ab444 100644 --- a/python/setup.py +++ b/python/setup.py @@ -7,7 +7,6 @@ import shutil import subprocess import sys import tarfile -import tempfile import urllib.request from distutils.version import LooseVersion @@ -15,6 +14,20 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext +# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py +def check_env_flag(name: str, default: str = "") -> bool: + return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + +def get_build_type(): + if check_env_flag("DEBUG"): + return "Debug" + elif check_env_flag("REL_WITH_DEB_INFO"): + return "RelWithDebInfo" + else: + return "Release" + + def get_llvm(): # tries to find system LLVM versions = ['-11.0', '-11', '-11-64'] @@ -80,15 +93,10 @@ 
class CMakeBuild(build_ext): def build_extension(self, ext): llvm_include_dir, llvm_library_dir = get_llvm() - self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories - build_suffix = 'debug' if self.debug else 'release' - llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix) if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - if not os.path.exists(llvm_build_dir): - os.makedirs(llvm_build_dir) # python directories python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] cmake_args = [ @@ -99,11 +107,10 @@ class CMakeBuild(build_ext): "-DLLVM_LIBRARY_DIR=" + llvm_library_dir, # '-DPYTHON_EXECUTABLE=' + sys.executable, # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', - "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir, "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) ] # configuration - cfg = "Debug" if self.debug else "Release" + cfg = get_build_type() build_args = ["--config", cfg] if platform.system() == "Windows": From 22105bc33b14ce4f9af669a50b6031bdb6ebb8c4 Mon Sep 17 00:00:00 2001 From: vesuppi Date: Sun, 3 Jul 2022 15:25:37 -0700 Subject: [PATCH 147/215] [FRONTEND] Added type check in semantic arange (#572) --- python/triton/language/semantic.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 9025319d6..a31fec384 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -436,6 +436,9 @@ def not_equal(input: tl.tensor, def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + shape = [end - start] ret_ty = tl.block_type(tl.int32, shape) return tl.tensor(builder.get_range(start, end), ret_ty) From 4a399a7e4058f8f54f4d2c6812462b3eb2bf2c9e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 6 Jul 2022 20:03:04 -0700 Subject: [PATCH 148/215] [BACKEND] Fix some bugs (atomics, a segfault...) (#577) This should fix #558 , #573 and #574 --- lib/codegen/selection/generator.cc | 41 +++++++++++++++++----- lib/codegen/transform/coalesce.cc | 48 ++++---------------------- python/test/unit/language/test_core.py | 23 ++++++++++++ python/triton/language/core.py | 11 ++++++ 4 files changed, 73 insertions(+), 50 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 48a15c4af..e2303b990 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1285,13 +1285,35 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { // vector size int vec = 1; + Value *mask = builder_->getInt1(true); if(atom->get_type()->is_block_ty()){ + auto shape = atom->get_type()->get_block_shapes(); int ld = ords_.at(ptr)[0]; unsigned alignment = alignment_->get(ptr, ld); vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment); vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 
2 : 1); + // mask out inactive threads + analysis::data_layout* layout = layouts_->get(val); + auto curr_axes = a_axes_->get(val); + auto layt_axes = layout->get_axes(); + for(unsigned k = 0; k < layt_axes.size(); k++){ + unsigned ax = layt_axes.at(k); + distributed_axis dax = axes_.at(ax); + // axis is part of the original layout: thread id should be 0 + // but not the current layout + if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end()) + mask = and_(mask, icmp_eq(dax.thread_id, i32(0))); + } + // last axis may spillover + Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0); + int per_thread = 1; + for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; } + int numel = 1; + for(int s: layout->get_shape()) { numel *= s; } + mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel))); } + for(int i = 0; i < idxs_.at(val).size(); i += vec){ auto idx = idxs_[val][i]; Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec)); @@ -1299,6 +1321,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii); Value *rmw_ptr = vals_[ptr][idx]; Value *rmw_msk = vals_[msk][idx]; + rmw_msk = and_(rmw_msk, mask); if(vec == 1) rmw_val = extract_elt(rmw_val, i32(0)); Type* ty = rmw_val->getType(); @@ -3400,20 +3423,20 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { } void generator::visit_layout_scanline(analysis::scanline_layout* layout) { - Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0); + Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0); auto order = layout->get_order(); const auto& shape = layout->get_shape(); // Delinearize size_t dim = shape.size(); - std::vector thread_id(dim); + std::vector thread_ids(dim); for(unsigned k = 0; k < dim - 1; k++){ Constant *dim_k = i32(layout->mts(order[k])); - Value *rem = urem(u_thread_id, dim_k); - u_thread_id = udiv(u_thread_id, dim_k); - thread_id[order[k]] = rem; + Value *rem = urem(thread_id, dim_k); + thread_id = udiv(thread_id, dim_k); + thread_ids[order[k]] = rem; } Constant *dim_k = i32(layout->mts(order[dim - 1])); - thread_id[order[dim - 1]] = urem(u_thread_id, dim_k); + thread_ids[order[dim - 1]] = urem(thread_id, dim_k); // Create axes for(unsigned k = 0; k < dim; k++) { @@ -3421,15 +3444,15 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) { int mts = layout->mts(k); std::string str_k = std::to_string(k); Value *contiguous_k = i32(nts); - Value *scaled_thread_id = mul(thread_id[k], contiguous_k); + Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k); unsigned per_cta = layout->shape_per_cta(k); unsigned per_thread = nts * shape[k] / per_cta; std::vector idx_list(per_thread); for(unsigned n = 0 ; n < per_thread; n++){ unsigned offset = n / nts * per_cta + n % nts; - idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); + idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); } - axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; + axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]}; } } diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 8b5ad3625..862ad1efe 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -15,42 +15,6 @@ namespace transform{ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80) : 
align_(align), layout_(layouts), has_sm80_(has_sm80) { } - -// simplify layout conversions using the following simple rules: -// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2 -// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y)) -//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){ -// ir::value* _op = inst->get_operand(0); -// ir::instruction* op = dynamic_cast(_op); -// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma(); -// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma(); -// std::cout << 1 << std::endl; -// // i must be layout conversion instruction -// if(!mma_in && !mma_out) -// return inst; -// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2 -// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT; -// if((mma_in || mma_out) && is_op_cvt && -// (layout_->get(inst) == layout_->get(op->get_operand(0)))) -// return op->get_operand(0); -// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y)) -// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR) -// return inst; -// std::cout << 1 << std::endl; -// for(size_t i = 0; i < op->get_num_operands(); i++){ -// ir::value* arg_i = op->get_operand(i); -// builder.set_insert_point(op); -// // create new layout transform -// ir::instruction* new_arg_i = inst->clone(); -// builder.insert(new_arg_i); -// // set the right args -// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); -// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder)); -// } -// std::cout << 2 << std::endl; -// return op; -//} - void coalesce::run(ir::module &mod) { std::set invalidated; ir::builder& builder = mod.get_builder(); @@ -62,7 +26,7 @@ void coalesce::run(ir::module &mod) { if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(1)) if(op->get_type()->is_block_ty()) - if(op->get_type()->get_tile_rank() == 2) + if(op->get_type()->get_tile_ranks1() == 2) if(invalidated.find(layout_->get(op)) == invalidated.end()) if(layout_->get(op)->to_mma()) if(dynamic_cast(i)->get_eviction_policy()==ir::io_inst::NORMAL){ @@ -78,7 +42,7 @@ void coalesce::run(ir::module &mod) { if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(0)) if(op->get_type()->is_block_ty()) - if(op->get_type()->get_tile_rank() == 2) + if(op->get_type()->get_tile_ranks1() == 2) if(invalidated.find(layout_->get(op)) == invalidated.end()) if(layout_->get(op)->to_mma()){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); @@ -91,7 +55,7 @@ void coalesce::run(ir::module &mod) { // uncoalesce after load if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) - if(x->get_type()->get_tile_rank()==2) + if(x->get_type()->get_tile_ranks1()==2) if(layout_->get(x)->to_mma()) if(!has_sm80_ || dynamic_cast(i)->get_eviction_policy()==ir::io_inst::NORMAL){ builder.set_insert_point_after(x); @@ -111,9 +75,11 @@ void coalesce::run(ir::module &mod) { auto out_contig = align_->contiguous(ptr); auto val_inst = dynamic_cast(val); if(!val_inst) - break; + continue; if(dynamic_cast(val)) - break; + continue; + if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1) + continue; std::vector in_contig; std::vector queue = {val_inst}; std::set seen; diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 92c854f06..d032d1e39 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -532,6 +532,29 @@ def test_atomic_rmw(op, dtype_x_str, mode, 
device='cuda'): np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) +@pytest.mark.parametrize("axis", [0, 1]) +def test_tensor_atomic_rmw(axis, device="cuda"): + shape0, shape1 = 8, 8 + # triton kernel + + @triton.jit + def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + z = tl.sum(x, axis=AXIS) + tl.atomic_add(Z + off0, z) + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) + # reference result + z_ref = np.sum(x, axis=axis) + # triton result + x_tri = to_triton(x, device=device) + z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device) + kernel[(1,)](z_tri, x_tri, axis, shape0, shape1) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + + def test_atomic_cas(): # 1. make sure that atomic_cas changes the original value (Lock) @triton.jit diff --git a/python/triton/language/core.py b/python/triton/language/core.py index ee99aab2e..cc0db5566 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -370,6 +370,17 @@ class constexpr: def __call__(self, *args, **kwds): return self.value(*args, **kwds) + def to(self, dtype, bitcast=False, _builder=None): + if dtype in [float8, float16, bfloat16]: + raise ValueError("floating point constexpr must be float64") + if dtype.is_int(): + ret_ty = int + elif dtype.is_bool(): + ret_ty = bool + elif dtype.is_floating(): + ret_ty = float + return constexpr(ret_ty(self.value)) + class tensor: # infer dtype from ir type From c9a2b9c7d414109173b30f2672daa8b2f7334080 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 11 Jul 2022 14:37:59 -0700 Subject: [PATCH 149/215] [FRONTEND] Add missing args to get_simd_tflops() (#578) --- python/triton/ops/matmul_perf_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index 9c10b88d8..004f236b9 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -26,7 +26,7 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): def get_tflops(backend, device, num_ctas, num_warps, dtype): cc = _triton.runtime.cc(backend, device) if cc < 80 and dtype == torch.float32: - return get_simd_tflops() + return get_simd_tflops(backend, device, num_ctas, num_warps, dtype) return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype) From d5eb9bc23000c843309980eafb8005f4f9f1e1b4 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 11 Jul 2022 15:43:46 -0700 Subject: [PATCH 150/215] [tutorial] Added bwd in fused attention example (#579) Doesn't work on V100 --- python/tutorials/06-fused-attention.py | 276 +++++++++++++++++++------ 1 file changed, 215 insertions(+), 61 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index eb9b40c60..030fe2c2b 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -1,3 +1,8 @@ +""" +Fused Attention +=============== +""" + import pytest import torch @@ -7,25 +12,25 @@ import triton.language as tl @triton.jit def _fwd_kernel( - Q, K, V, + Q, K, V, sm_scale, TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kk, stride_kn, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, 
stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): - start_qm = tl.program_id(0) + start_m = tl.program_id(0) off_hz = tl.program_id(1) # initialize offsets - offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk # Initialize pointers to Q, K, V q_ptrs = Q + off_q @@ -33,17 +38,20 @@ def _fwd_kernel( v_ptrs = V + off_v # initialize pointer to m and l t_ptrs = TMP + off_hz * N_CTX + offs_m - - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout q = tl.load(q_ptrs) - for start_n in range(0, start_qm + 1): + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs) - qk = tl.dot(q, k) - qk += tl.where(offs_m[:, None] >= (start_n * BLOCK_N + offs_n[None, :]), 0, float("-inf")) + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) p = tl.exp(qk - m_ij[:, None]) @@ -57,23 +65,21 @@ def _fwd_kernel( # scale p p_scale = beta / l_i_new p = p * p_scale[:, None] - p = p.to(tl.float16) # scale acc acc_scale = l_i / l_i_new * alpha tl.store(t_ptrs, acc_scale) acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs) + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) acc += tl.dot(p, v) - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - # r_ptrs += BLOCK_N + # update m_i and l_i l_i = l_i_new m_i = m_i_new - - start_qm = tl.program_id(0) - offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M) + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # write back l and m l_ptrs = L + off_hz * N_CTX + offs_m m_ptrs = M + off_hz * N_CTX + offs_m @@ -81,18 +87,122 @@ def _fwd_kernel( tl.store(m_ptrs, m_i) # initialize pointers to output offs_n = tl.arange(0, BLOCK_DMODEL) - off_out = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on - out_ptrs = Out + off_out + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o tl.store(out_ptrs, acc) +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + 
off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v): + def forward(ctx, q, k, v, sm_scale): BLOCK = 128 # shape constraints - Lq, Lk = q.shape[-1], k.shape[-2] + Lq, 
Lk = q.shape[-1], k.shape[-1] assert Lq == Lk o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) @@ -100,7 +210,7 @@ class _attention(torch.autograd.Function): L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( - q, k, v, + q, k, v, sm_scale, tmp, L, m, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -115,30 +225,76 @@ class _attention(torch.autograd.Function): ctx.save_for_backward(q, k, v, o, L, m) ctx.BLOCK = BLOCK ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = 64 return o + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + num_stages=1, + ) + return dq, dk, dv, None + attention = _attention.apply -@pytest.mark.parametrize('Z, H, N_CTX, D_MODEL', [(2, 3, 1024, 64)]) -def test_op(Z, H, N_CTX, D_MODEL, dtype=torch.float16): +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) - q = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) - k = .5 * torch.randn((Z, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True) - v = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) - # triton implementation - tri_out = attention(q, k, v) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - ref_qk = torch.matmul(q, k) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale for z in range(Z): for h in range(H): - ref_qk[:, :, M == 0] = float("-inf") - ref_qk = torch.softmax(ref_qk, dim=-1) - ref_out = torch.matmul(ref_qk, v) + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, sm_scale) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None # compare triton.testing.assert_almost_equal(ref_out, tri_out) + triton.testing.assert_almost_equal(ref_dv, tri_dv) + triton.testing.assert_almost_equal(ref_dk, tri_dk) + 
triton.testing.assert_almost_equal(ref_dq, tri_dq) try: @@ -147,21 +303,9 @@ try: except BaseException: HAS_FLASH = False -BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 2048, 64 -# vary batch size for fixed heads / seq -batch_bench = triton.testing.Benchmark( - x_names=['BATCH'], - x_vals=[2**i for i in range(0, 8)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-seq{N_CTX}-head{N_HEADS}-d{D_HEAD}', - args={'H': N_HEADS, 'N_CTX': N_CTX, 'D_MODEL': D_HEAD, 'dtype': torch.float16} -) +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -seq_bench = triton.testing.Benchmark( +configs = [triton.testing.Benchmark( x_names=['N_CTX'], x_vals=[2**i for i in range(10, 16)], line_arg='provider', @@ -169,28 +313,38 @@ seq_bench = triton.testing.Benchmark( line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}', - args={'H': D_HEAD, 'BATCH': BATCH, 'D_MODEL': D_HEAD, 'dtype': torch.float16} -) + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} +) for mode in ['bwd']] -@triton.testing.perf_report([batch_bench, seq_bench]) -def bench_flash_attention(BATCH, H, N_CTX, D_MODEL, provider, dtype=torch.float16, device="cuda"): +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] warmup = 25 - rep = 500 + rep = 100 if provider == "triton": - q = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((BATCH, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True) - fn = lambda: attention(q, k, v) + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms if provider == "flash": lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) cu_seqlens[1:] = lengths.cumsum(0) - qkv = torch.randn((BATCH * N_CTX, 3, H, D_MODEL), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms From 971f5782b40d609a648dc35ae11c80993b5ce427 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 11 Jul 2022 18:56:48 -0700 Subject: [PATCH 151/215] [tutorials] Added flash attention credits in tutorial --- 
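A note on the forward kernel this tutorial builds on: its alpha / beta / acc_scale arithmetic is the streaming ("online") softmax trick, which lets softmax(q @ k^T) @ v be accumulated one key/value tile at a time. Below is a minimal NumPy sketch of one such update; the variable names mirror the kernel's registers, but the function itself is only an illustration and not code from this patch.

import numpy as np

def online_softmax_update(acc, m_i, l_i, qk_block, v_block):
    # Fold one (BLOCK_M, BLOCK_N) tile of scores and its (BLOCK_N, D) tile of
    # values into the running output without materializing the full softmax.
    m_ij = qk_block.max(axis=1)            # row maxima of the new tile
    p = np.exp(qk_block - m_ij[:, None])   # tile-local unnormalized probabilities
    l_ij = p.sum(axis=1)
    m_new = np.maximum(m_i, m_ij)          # updated running row maxima
    alpha = np.exp(m_i - m_new)            # rescales the old statistics
    beta = np.exp(m_ij - m_new)            # rescales the new tile
    l_new = alpha * l_i + beta * l_ij
    acc = acc * (l_i / l_new * alpha)[:, None] + (p * (beta / l_new)[:, None]) @ v_block
    return acc, m_new, l_new

Seeding m_i with -inf, l_i with 0 and acc with 0, then applying this update over every key/value tile of a row block, recovers the row-normalized softmax(QK^T)V output up to floating-point error, which is exactly what the kernel's m_i, l_i and acc registers track.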
python/tutorials/06-fused-attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 030fe2c2b..89aadb1b4 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -1,6 +1,7 @@ """ Fused Attention =============== +This is a Triton implementation of the Flash Attention algorithm (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf) """ import pytest From 4912916c11f68037510fce44bdbdcf1292550cb8 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 13 Jul 2022 15:52:21 -0700 Subject: [PATCH 152/215] [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) --- include/triton/codegen/extern_lib.h | 89 + include/triton/codegen/pass.h | 11 +- include/triton/codegen/selection/generator.h | 14 +- include/triton/ir/builder.h | 6 + include/triton/ir/enums.h | 2 + include/triton/ir/instructions.h | 21 + include/triton/ir/visitor.h | 4 + lib/codegen/extern_lib.cc | 63 + lib/codegen/pass.cc | 89 +- lib/codegen/selection/generator.cc | 39 +- lib/driver/llvm.cc | 7 +- lib/ir/builder.cc | 13 + lib/ir/instructions.cc | 22 + python/setup.py | 2 +- python/src/triton.cc | 61 +- python/test/unit/language/test_core.py | 46 + python/triton/code_gen.py | 19 +- python/triton/language/__init__.py | 2 +- python/triton/language/core.py | 6 +- python/triton/language/extern.py | 107 ++ python/triton/language/libdevice.10.bc | Bin 0 -> 469572 bytes python/triton/language/libdevice.py | 1661 ++++++++++++++++++ python/triton/tools/build_extern.py | 340 ++++ python/tutorials/07-libdevice-function.py | 74 + 24 files changed, 2634 insertions(+), 64 deletions(-) create mode 100644 include/triton/codegen/extern_lib.h create mode 100644 lib/codegen/extern_lib.cc create mode 100644 python/triton/language/extern.py create mode 100644 python/triton/language/libdevice.10.bc create mode 100644 python/triton/language/libdevice.py create mode 100644 python/triton/tools/build_extern.py create mode 100644 python/tutorials/07-libdevice-function.py diff --git a/include/triton/codegen/extern_lib.h b/include/triton/codegen/extern_lib.h new file mode 100644 index 000000000..c161ff142 --- /dev/null +++ b/include/triton/codegen/extern_lib.h @@ -0,0 +1,89 @@ +#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_ +#define _TRITON_CODE_GEN_EXTERN_LIB_H_ + +#include +#include + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" + +namespace triton { +namespace codegen { + +/// +/// \brief ExternLib is a class that represents a library of external functions. +/// +class ExternLib { + public: + ExternLib(const std::string &name, const std::string &path) + : name_(name), path_(path) {} + + virtual ~ExternLib() = default; + + virtual const std::string &name() const { return name_; } + + virtual const std::string &path() const { return path_; } + + /// + /// \brief Load the library and return the module. + /// + std::unique_ptr load(llvm::LLVMContext &ctx); + + /// + /// \brief Link the module into the given module. + /// + void link(std::unique_ptr &llvm, + std::unique_ptr &mod); + + /// + /// \brief Run load, link, and opt on the module. + /// + virtual void install(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) { + auto mod = load(ctx); + link(llvm, mod); + opt(ctx, llvm); + } + + /// + /// \brief Run opt on the module. 
+ /// + virtual void opt(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) = 0; + + private: + std::string name_; + std::string path_; +}; + +/// +/// \brief ExternLibMap is a map of ExternLibs from their names to their paths. +/// +typedef std::map> ExternLibMap; + +/// +/// \brief Concrete class for NVIDIA's libdevice library. +/// +class LibDevice final : public ExternLib { + public: + LibDevice(const std::string &name, const std::string &path) + : ExternLib(name, path) {} + + virtual ~LibDevice() = default; + + virtual void opt(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) override; +}; + +/// +/// \brief Create an ExternLib instance based on the name and path. +/// +std::unique_ptr create_extern_lib(const std::string &lib_name, + const std::string &lib_path); + +} // namespace codegen +} // namespace triton + +#endif diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h index 0c8f11315..95b00b807 100644 --- a/include/triton/codegen/pass.h +++ b/include/triton/codegen/pass.h @@ -3,6 +3,7 @@ #include +#include "extern_lib.h" namespace llvm{ class Module; @@ -30,12 +31,10 @@ namespace codegen{ // TODO: // There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, - codegen::target* target, - int sm, int num_warps, - int num_stages, int &shared_static); - - +std::unique_ptr add_passes_to_emit_bin( + ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target, + int num_warps, int num_stages, int &shared_static, + const ExternLibMap &extern_libs); } } diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index b408a46ca..7867c356b 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -6,6 +6,7 @@ #include "triton/ir/visitor.h" #include "triton/ir/instructions.h" #include "triton/codegen/analysis/layout.h" +#include "triton/codegen/extern_lib.h" #include // forward @@ -199,6 +200,7 @@ private: void visit_make_range(ir::make_range*); void visit_clock_inst(ir::clock_inst*); void visit_globaltimer_inst(ir::globaltimer_inst*); + void visit_extern_elementwise_inst(ir::extern_elementwise_inst*); // void visit_make_range_sta(ir::make_range_sta*); void visit_undef_value(ir::undef_value*); void visit_constant_int(ir::constant_int*); @@ -209,18 +211,26 @@ private: void visit_argument(ir::argument*); void visit(ir::module &, llvm::Module &); - // layouts void visit_layout_mma(analysis::mma_layout*); void visit_layout_scanline(analysis::scanline_layout*); void visit_layout_shared(analysis::shared_layout*); + // Add a new external library based on given name and path if it doesn't exist + void add_extern_lib(const std::string &lib_name, const std::string &lib_path); -private: + // Get all external libraries + const ExternLibMap &get_extern_lib_map() { + return extern_lib_map_; + } + + private: LLVMContext *ctx_; Builder* builder_; Module *mod_; + std::map> extern_lib_map_; + analysis::axes *a_axes_; analysis::swizzle *swizzle_; std::map axes_; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 74028f822..8eb1c2ce3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -169,6 +169,12 @@ public: // Utilities value *create_clock(); value *create_globaltimer(); + // Extern instruction + value *create_extern_elementwise(const std::string &lib_name, + const std::string &lib_path, + const std::string &symbol_name, + const std::vector &args, + type *ret_ty); // 
Built-in instruction value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 3fa008606..4e60d3444 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -154,6 +154,8 @@ enum value_id_t: unsigned { INST_COS, INST_SIN, INST_LOG, + // extern + INST_EXTERN_ELEMENTWISE, // array arithmetic INST_TRANS, INST_REDUCE, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 402208a8b..1bad86c33 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -1097,7 +1097,28 @@ public: static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr); }; +class extern_elementwise_inst : public instruction { + extern_elementwise_inst(context &ctx, const std::vector &args, + type *dst_ty, const std::string &lib_name, + const std::string &extern_lib_path, + const std::string &symbol_name, instruction *next); + std::string repr_impl() const { return "extern_elementwise"; } + _TRITON_DEFINE_CLONE(extern_elementwise_inst) + _TRITON_DEFINE_ACCEPT(extern_elementwise_inst) + public: + static extern_elementwise_inst *create( + context &ctx, const std::vector &args, type *dst_ty, + const std::string &lib_name = "", const std::string &lib_path = "", + const std::string &symbol_name = "", instruction *next = nullptr); + + const std::string &get_lib_name() const { return lib_name_; } + const std::string &get_lib_path() const { return lib_path_; } + + private: + std::string lib_name_ = ""; + std::string lib_path_ = ""; +}; } } diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 774f2e172..5f84f414f 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -84,6 +84,8 @@ class prefetch_s_inst; class clock_inst; class globaltimer_inst; +class extern_elementwise_inst; + class make_range_sta; class undef_value; class constant_int; @@ -177,6 +179,8 @@ public: virtual void visit_constant_int(constant_int*) = 0; virtual void visit_constant_fp(constant_fp*) = 0; virtual void visit_alloc_const(alloc_const*) = 0; + + virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0; }; } diff --git a/lib/codegen/extern_lib.cc b/lib/codegen/extern_lib.cc new file mode 100644 index 000000000..0a1f165ea --- /dev/null +++ b/lib/codegen/extern_lib.cc @@ -0,0 +1,63 @@ +#include "triton/codegen/extern_lib.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Type.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "triton/codegen/pass.h" + +namespace triton { + +namespace codegen { + +std::unique_ptr ExternLib::load(llvm::LLVMContext& ctx) { + llvm::SMDiagnostic err; + auto mod = llvm::parseIRFile(this->path_, err, ctx); + if (!mod) { + throw std::runtime_error("Failed to load extern lib " + this->name_ + + " at " + this->path_); + } + return mod; +} + +void ExternLib::link(std::unique_ptr& llvm, + std::unique_ptr& mod) { + // Set triple and data layout to match the target module + mod->setTargetTriple(llvm->getTargetTriple()); + mod->setDataLayout(llvm->getDataLayout()); + if (llvm::Linker::linkModules(*llvm, std::move(mod))) { + throw std::runtime_error("Failed to link extern lib " + this->name_ + + " at " + this->path_); + } +} + +void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr& llvm) { + // Add nvvm reflect flags to llvm module 
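+  // Added note: "nvvm-reflect-ftz" is the module flag that
+  // __nvvm_reflect("__CUDA_FTZ") resolves to once libdevice is linked in;
+  // a value of 1 selects libdevice's flush-to-zero code paths.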
+ // https://llvm.org/docs/LangRef.html#module-flags-metadata + // i32 4: Override the other module. + // i32 1: Emit an error + // If both modules specify Override, but the values differ, an error + // will be emitted. + llvm::Type* I32 = llvm::Type::getInt32Ty(ctx); + llvm::Metadata* md_four = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4)); + llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz"); + llvm::Metadata* md_one = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1)); + llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one}); + llvm->addModuleFlag(reflect); +} + +std::unique_ptr create_extern_lib(const std::string& lib_name, + const std::string& lib_path) { + if (lib_name == "libdevice") { + return std::make_unique(lib_name, lib_path); + } else { + throw std::runtime_error("Unknown external library: " + lib_name); + } +} + +} // namespace codegen +} // namespace triton diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 412e2f4c8..645f10978 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -1,4 +1,14 @@ #include "triton/codegen/pass.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/axes.h" @@ -9,24 +19,66 @@ #include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/disassociate.h" +#include "triton/codegen/transform/inline.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" #include "triton/codegen/transform/prefetch.h" -#include "triton/codegen/transform/inline.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/print.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" + namespace triton { namespace codegen { +static void link_extern_libs(const ExternLibMap& user_extern_lib_map, + const ExternLibMap& target_extern_lib_map, + ir::module& ir, llvm::LLVMContext& ctx, + std::unique_ptr& llvm) { + for (const auto& iter : target_extern_lib_map) { + auto &lib_name = iter.first; + if (user_extern_lib_map.count(lib_name) != 0 && + user_extern_lib_map.at(lib_name)->path() != "") { + // If the user specified a path for this library, use it. + user_extern_lib_map.at(lib_name)->install(ctx, llvm); + } else { + // Otherwise, use the default path. 
+ iter.second->install(ctx, llvm); + } + } + + std::set function_names; + for (auto& func : ir.get_function_list()) { + function_names.insert(func->get_name()); + } + llvm::legacy::PassManager pass; + pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool { + if (function_names.count(v.getName()) != 0) { + // Preserve global functions + return true; + } + // Internalize all device functions + return false; + })); + + llvm::legacy::PassManager pm; + pm.add(llvm::createVerifierPass()); + pm.run(*llvm); + + llvm::PassManagerBuilder builder; + builder.OptLevel = 3; + builder.SizeLevel = 0; + builder.populateModulePassManager(pass); + + pass.run(*llvm); +} + // TODO: // There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, - int cc, int num_warps, int num_stages, int& shared_static) { +std::unique_ptr add_passes_to_emit_bin( + ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target, + int num_warps, int num_stages, int& shared_static, + const ExternLibMap& extern_lib_map) { // generate llvm code std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); @@ -47,8 +99,10 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC codegen::transform::peephole peephole(target, &layouts); codegen::transform::coalesce coalesce(&align, &layouts, has_sm80); codegen::transform::prefetch prefetch_s(target); - codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); - codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); + codegen::transform::membar barriers(&liveness, &layouts, &allocation, + &prefetch_s, target); + codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, + target, num_warps); // run passes inliner.run(ir); dce.run(ir); @@ -56,7 +110,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC peephole.run(ir); dce.run(ir); pipeline.run(ir); - dce.run(ir); + dce.run(ir); disassociate.run(ir); dce.run(ir); align.run(ir); @@ -64,8 +118,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC layouts.run(ir); peephole.run(ir); dce.run(ir); - if (target->is_gpu()) - cts.run(ir); + if (target->is_gpu()) cts.run(ir); align.run(ir); axes.run(ir); layouts.run(ir); @@ -73,8 +126,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC dce.run(ir); align.run(ir); dce.run(ir); - if (target->is_gpu()) - cts.run(ir); + if (target->is_gpu()) cts.run(ir); dce.run(ir); align.run(ir); axes.run(ir); @@ -97,8 +149,15 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); + + if (isel.get_extern_lib_map().size() > 0) { + // If there's any extern lib calls, + // we need to link them in. 
+ link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm); + } + return llvm; } -} // namespace codegen -} // namespace triton +} // namespace codegen +} // namespace triton diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e2303b990..b30283ced 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){ for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); } - } +} /** * \brief Code Generation for `umulhi` @@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) { call(iasm); } +/** + * \brief Code Generation for `extern_elementwise` + */ +void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) { + std::vector operand_types; + for (size_t j = 0; j < i->get_num_operands(); j++) { + operand_types.push_back( + cvt(i->get_operand(j)->get_type()->get_scalar_ty())); + } + Type *ret_type = cvt(i->get_type()->get_scalar_ty()); + FunctionType *FT = + FunctionType::get(ret_type, std::move(operand_types), false); + Function *F = llvm::cast( + mod_->getOrInsertFunction(i->get_name(), FT).getCallee()); + for (auto idx : idxs_.at(i)) { + std::vector args; + for (size_t j = 0; j < i->get_num_operands(); j++) { + args.emplace_back(vals_[i->get_operand(j)][idx]); + } + vals_[i][idx] = call(F, std::move(args)); + } + add_extern_lib(i->get_lib_name(), i->get_lib_path()); +} + //void generator::visit_make_range_dyn(ir::make_range_dyn* x) { // for(indices_t idx: idxs_.at(x)){ // assert(idx.size() == 1); @@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) { visit_function(fn); } +void generator::add_extern_lib(const std::string &lib_name, + const std::string &lib_path) { + if (extern_lib_map_.count(lib_name) == 0) { + extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path); + } else if (extern_lib_map_.at(lib_name)->path() != lib_path) { + throw std::runtime_error("A library has multiple paths (1) " + lib_path + + " (2) " + extern_lib_map_.at(lib_name)->path()); + } +} -} -} +} // namespace codegen +} // namespace triton diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 92a6b75de..c4a13b806 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -358,8 +358,5 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) { return ret; } - - -} -} - +} // namespace driver +} // namespace triton diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 510994fd8..120b575cf 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -379,6 +379,19 @@ value *builder::create_globaltimer() { return insert(globaltimer_inst::create(ctx_)); } +//===----------------------------------------------------------------------===// +// externs +//===----------------------------------------------------------------------===// + +value *builder::create_extern_elementwise(const std::string &lib_name, + const std::string &lib_path, + const std::string &symbol_name, + const std::vector &args, + type *ret_ty) { + return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name, + lib_path, symbol_name)); +} + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index dbee5e0ee..7831e1650 100644 --- a/lib/ir/instructions.cc +++ 
b/lib/ir/instructions.cc @@ -988,6 +988,28 @@ globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name return new globaltimer_inst(ctx, name, next); } +// extern elementwise +extern_elementwise_inst::extern_elementwise_inst( + context &ctx, const std::vector &args, type *ret_ty, + const std::string &lib_name, const std::string &lib_path, + const std::string &symbol_name, instruction *next) + : instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), symbol_name, + next), + lib_name_(lib_name), + lib_path_(lib_path) { + for (size_t i = 0; i < args.size(); i++) { + set_operand(i, args[i]); + } +} + +extern_elementwise_inst *extern_elementwise_inst::create( + context &ctx, const std::vector &args, type *ret_ty, + const std::string &lib_name, const std::string &lib_path, + const std::string &symbol_name, instruction *next) { + return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path, + symbol_name, next); +} + // clock clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next) : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { } diff --git a/python/setup.py b/python/setup.py index 7ed6ab444..6c136b6c7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -98,7 +98,7 @@ class CMakeBuild(build_ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) # python directories - python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] + python_include_dirs = [distutils.sysconfig.get_python_inc()] cmake_args = [ "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DBUILD_TUTORIALS=OFF", diff --git a/python/src/triton.cc b/python/src/triton.cc index 4e1849733..fcebeeb5f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,5 +1,6 @@ #include "triton/codegen/pass.h" #include "triton/codegen/target.h" +#include "triton/codegen/extern_lib.h" #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" @@ -19,7 +20,6 @@ #include #include #include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" namespace py = pybind11; @@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){ // Launch void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, - int num_warps, int num_stages) { + int num_warps, int num_stages, py::dict& extern_libs) { size_t len = PyList_Size(args.ptr()); params.reserve(8*len); // 8 max bytes by argument char* params_ptr = ¶ms[0]; @@ -256,6 +256,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f throw std::runtime_error(err_msg); } params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]); + + for (auto item : extern_libs) { + cache_key += "-" + item.first.cast(); + cache_key += "_" + item.second.cast(); + } } // @@ -288,7 +293,7 @@ void init_triton_runtime(py::module &&m) { // cache key m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, - py::function add_to_cache, py::object grid){ + py::dict extern_libs, py::function add_to_cache, py::object grid){ // parse arguments to compute cache key, compile-time constants and packed kernel arguments long _num_warps = PyLong_AsLong(num_warps.ptr()); long _num_stages = PyLong_AsLong(num_stages.ptr()); @@ -296,13 
+301,14 @@ void init_triton_runtime(py::module &&m) { std::string params; size_t params_size; py::dict constants; - parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); + parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, + params_size, constants, _num_warps, _num_stages, extern_libs); // get cached binary py::str key(cache_key); py::bool_ noop = false; if(!bin_cache.contains(key)) { - noop = add_to_cache(key, args, device, num_warps, num_stages); + noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs); } if (noop) return (py::object)py::none(); @@ -467,11 +473,10 @@ std::tuple hip_load_binary(const std::st // --------------------------------------- // CUDA -std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ - - int n_shared_bytes; +std::tuple cu_compile_ttir( + const std::string &name, ir::module &ir, uint64_t device, int num_warps, + int num_stages, asm_map_t &asm_map, + const triton::codegen::ExternLibMap &extern_lib_map) { py::gil_scoped_release allow_threads; llvm::LLVMContext ctx; // device properties @@ -483,7 +488,9 @@ std::tuple cu_compile_ttir(const std::string& name, std::string ptxas_path = drv::path_to_ptxas(version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); + int n_shared_bytes; + auto llvm = triton::codegen::add_passes_to_emit_bin( + ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map); std::string tmp; llvm::raw_string_ostream llir(tmp); llir << *llvm; @@ -502,14 +509,16 @@ std::tuple cu_compile_ttir(const std::string& name, } // HIP -std::tuple hip_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ +std::tuple hip_compile_ttir( + const std::string &name, ir::module &ir, uint64_t device, int num_warps, + int num_stages, asm_map_t &asm_map, + const triton::codegen::ExternLibMap &extern_lib_map) { llvm::LLVMContext ctx; // Triton-IR -> NVPTX LLVM-IR triton::codegen::amd_cl_target target; int n_shared_bytes; - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes); + auto llvm = triton::codegen::add_passes_to_emit_bin( + ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map); std::string tmp; llvm::raw_string_ostream llir(tmp); llir << *llvm; @@ -523,7 +532,9 @@ std::tuple hip_compile_ttir(const std::string& name void init_triton_codegen(py::module &&m) { m.def( - "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) { + "compile_ttir", + [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, + int num_stages, py::dict& extern_libs) { std::string name = ir.get_function_list()[0]->get_name(); // record asm as we generate asm_map_t asm_map; @@ -531,11 +542,20 @@ void init_triton_codegen(py::module &&m) { ir.print(ttir); asm_map["ttir"] = py::cast(ttir.str()); llvm::LLVMContext ctx; + // construct extern lib map + triton::codegen::ExternLibMap extern_lib_map; + for (auto item : extern_libs) { + auto name = item.first.cast(); + auto path = item.second.cast(); + extern_lib_map.emplace( + name, triton::codegen::create_extern_lib(name, path)); + } if(backend == CUDA) - return cu_compile_ttir(name, 
ir, device, num_warps, num_stages, asm_map); + return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); - }, py::return_value_policy::take_ownership); + return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); + }, + py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ py::gil_scoped_release allow_threads; if(backend == CUDA) @@ -931,7 +951,8 @@ void init_triton_ir(py::module &&m) { // Utilities .def("create_clock", &ir::builder::create_clock, ret::reference) .def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference) - + // Extern instruction + .def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference) // Built-in instruction .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d032d1e39..cb2cb9c33 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1300,3 +1300,49 @@ def test_num_warps_pow2(): _kernel[(1,)](dst=dst, num_warps=1) _kernel[(1,)](dst=dst, num_warps=2) _kernel[(1,)](dst=dst, num_warps=4) + +# ------------- +# test extern +# ------------- + + +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('int32', 'libdevice.ffs', ''), + ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('float64', 'libdevice.norm4d', '')]) +def test_libdevice(dtype_str, expr, lib_path): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + + if expr == 'libdevice.ffs': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) + y_ref = np.zeros(shape, dtype=x.dtype) + for i in range(shape[0]): + y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() + elif expr == 'libdevice.pow': + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref = np.power(x, x) + elif expr == 'libdevice.norm4d': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) + y_ref = np.sqrt(4 * np.power(x, 2)) + + x_tri = to_triton(x) + # triton result + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + if expr == 'libdevice.ffs': + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + else: + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 30a79bcc9..3951d8b6b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -689,7 +689,7 @@ class CodeGenerator(ast.NodeVisitor): ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type) return ret # built-in function - if sys.modules[fn.__module__] is triton.language.core: + if sys.modules[fn.__module__] 
is triton.language.core or isinstance(fn, triton.language.extern.ExternalFunction): ret = fn(*args, _builder=self.builder, **kws) if fn in self.value_constructor.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg @@ -933,7 +933,7 @@ class Kernel: self.fn = fn self.cache_key = {} - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): + def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes @@ -953,9 +953,10 @@ class Kernel: constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) + return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, + extern_libs=extern_libs, is_manual_warmup=False) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs): assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2." # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -985,7 +986,7 @@ class Kernel: cache_key = self.cache_key[device] stream = current_cuda_stream(device) return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, - device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, + device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache, grid) @@ -1242,7 +1243,7 @@ class JITFunction: def warmup(self, compile): return self._warmup(**compile, is_manual_warmup=True) - def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup): + def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup): hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory @@ -1264,7 +1265,7 @@ class JITFunction: with open(bin_cache_path, 'rb') as f: binary = pickle.load(f)["binary"] - compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages) + compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs) if JITFunction.cache_hook is not None: name = self.__name__ info = key.split('-')[-3:] @@ -1293,7 +1294,7 @@ class JITFunction: self.bin_cache[key] = LoadedBinary(device, binary) return False - def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): + def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs): # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel @@ -1316,7 +1317,7 @@ class JITFunction: backend = _triton.runtime.backend.CUDA else: backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, 
generator.module, device, num_warps, num_stages) + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs) max_shared_memory = _triton.runtime.max_shared_memory(backend, device) if shared_mem > max_shared_memory: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0b04465eb..6b0058dd5 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa: F401 -from . import core, random +from . import core, extern, libdevice, random from .core import * from .random import * diff --git a/python/triton/language/core.py b/python/triton/language/core.py index cc0db5566..4197a3333 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -248,8 +248,10 @@ class block_type(dtype): # while tensor's shape is a list of constexpr self.shape = shape self.numel = 1 - for s in self.shape: - self.numel *= s + for i, s in enumerate(self.shape): + if isinstance(s, constexpr): + self.shape[i] = s.value + self.numel *= self.shape[i] self.name = self.__str__() diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py new file mode 100644 index 000000000..a306a2e9a --- /dev/null +++ b/python/triton/language/extern.py @@ -0,0 +1,107 @@ +from __future__ import annotations # remove after python 3.11 + +from . import core, semantic + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None): + ''' + Dispatch a function to a library + + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, core.tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." 
+ f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type + return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type) + + +def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None): + ''' + Dispatch an elementwise function to a library + + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param _builder: the builder + + :return: the return value of the function + ''' + dispatch_args = args.copy() + if len(args) == 1: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + ret_shape = dispatch_args[0].shape + elif len(args) == 2: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) + dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( + dispatch_args[0], dispatch_args[1], _builder) + ret_shape = dispatch_args[0].shape + else: + for i in range(len(dispatch_args)): + dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i in range(len(dispatch_args)): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + ret_shape = broadcast_arg.shape + func = getattr(_builder, "create_extern_elementwise") + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) + + +class ExternalFunction: + ''' + A wrapper for external functions + ''' + + def __init__(self, fn): + self.fn = fn + + def __call__(self, *args, **kwargs): + if '_builder' not in kwargs or \ + kwargs['_builder'] is None: + raise ValueError("Did you forget to add @triton.jit ? 
(`_builder` argument must be provided outside of JIT functions.)")
+        return self.fn(*args, **kwargs)
+
+
+def extern(fn):
+    '''
+        A decorator for external functions
+    '''
+    return ExternalFunction(fn)
diff --git a/python/triton/language/libdevice.10.bc b/python/triton/language/libdevice.10.bc
new file mode 100644
index 0000000000000000000000000000000000000000..ef3ae8d81946e5401289886e62e415d6e10783e8
GIT binary patch
literal 469572
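For readers of this patch, a minimal usage sketch (not part of the change itself) may help: it shows how a libdevice-style wrapper could be built on the new extern.extern and extern.elementwise helpers and then called from a @triton.jit kernel. The symbol names __nv_sinf/__nv_sin, the LIBDEVICE_PATH location, and the "libdevice" key used below are illustrative assumptions rather than definitions taken from this patch.

import os

import triton
import triton.language as tl
from triton.language import core, extern

# Assumed location: the bitcode file added by this patch sits next to the
# triton.language package sources.
LIBDEVICE_PATH = os.path.join(os.path.dirname(tl.__file__), "libdevice.10.bc")


@extern.extern
def sin(arg0, _builder=None):
    # Pick a libdevice symbol based on the argument dtype; the symbol names
    # here are assumptions for illustration only.
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0],
                              {(core.float32,): ("__nv_sinf", core.float32),
                               (core.float64,): ("__nv_sin", core.float64)},
                              _builder=_builder)


@triton.jit
def sin_kernel(X, Y, BLOCK_SIZE: tl.constexpr):
    off = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + off)
    tl.store(Y + off, sin(x))

At launch time the bitcode library would presumably be supplied through the new extern_libs keyword that Kernel.__call__ now forwards to _warmup/_compile and on to compile_ttir, e.g. sin_kernel[(1,)](x, y, BLOCK_SIZE=1024, extern_libs={"libdevice": LIBDEVICE_PATH}).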
zM1tMh!gMmukw9qFs`ily?bD6dS@*@a3jY?~Vr2-W4o^L%TRr!Xw`cObEA=5LZ`{I+ zBo(6xbVEya1U0MX_^Rn6-*?0oi zZX1)cO)tHqN{&#fA`}pJX1rBM$DN69CsA={)-$j)$u>ygAiHw{OPdEL?~rU$xo?A+ zb^MTeu3cPv4YgnNn_-Sr19PNuRF~9q8|Twoc+u#GOJG>^_HUkt*gbt{p?U1 zn8**s5)|8*Q5-Ub&P)`rhuVDcxM8<9f-I&fceC%@W(~DDVeu#R>YsTWRum`2B2YEX zhRlO2_ZLNipllurYPUhQh}H?+5@v4!{2X^a$24&(<`($Cs_!At5hX`)3ZpoVit^UQ zJk@r=pFLsf_h3(4%$v_S(weY>Z43A*Gmr;SOp^Ff1hrlOO#FRFh9ey9<;YVp7t4}VV~vdwj9ql7m-4ToIBb^{*e<2pRg36`)^8Mq_3S6_2|rP$8``)> z9bz*s1RDAxK?fX868S{)XY9LFh)<2m};M^+heZQ$TWKBSJ zM9_q-@PTrS6`W^g$50G(s)i-qU8FskI|Nroz=9akwtd>3_l!1i^9#*mFmdiV2 zShl%Nye7!5lb>&_V0m7b@7Ay4=vAjMob4l?o3x5Ed=&^Scj0@Nyk6vTL+9N6ednrTGwCwA=i-%=8&n$?E5d8!(qgX+A_=HO)2bmIe0f z1@EKFy{A_cNuOYR#%|Rsj;w7m>8;H#HBYczo?yEuBZb-Kb~#wI-bTG1yzyse5)f1N zLivQX_(`FMDh;i}%vK4riJRz41VgYTR(7Z`SybN z)=2b}j?TPJaL6=%Xd5KYfGF@85@kjz%tXe6MWxM9S0$9GS5X=AjtY|B0SpzXwa^gB5+tLR?-meIY6TkFXz7p%m3uARv&yNFjXCwtc7zom>20JfU+)Jo{s+Lhz?d#0zM^`od72^c}ieDgLd87gXO?^~% zv`Bdagp{-p^P<4t?>Sh017-O=Dh0t;svZgUA<+TK@*6k<%|!0L{F^($^1ig?lN+Qy zae&i6;S6x*4Ww}jUTG0WN8uPFAt4IEksmJ-ze3B?II63H;$py2ZKN!}M^#GWFslSd z5$70Xpp6{48K;LY|G{p+38QfiHAv-Jz)7BmemT^cC#P}hUP)g9j{GHBCvBc&KaImU zQKVac7P7(YCuU~e5crRH0oK`yIPUvZ5;!D7msP4B3(g?UIm$X)IT2=@{=WP$1)Qbd zN)LtoeTCsl(=HGizh5P%h4l4;3n*kjA@(-TJhRY1Uw%X`oTb`E3k~!Y&efTAfsksy zN<#}VUkIL{&{N7l+c*o%LPLG|^NnzpY&$J9)K|EG@sipFLdW*2bhMD{rJx>#UZ9YA zJ7)hz{r_a=aD@)$NmL= zxj;EpoHK8Q^Yk+|t%-?-y|C%j8{+cjbZD3-hM|ImSx*}MdIZz)c>2$RF1?|YZoRRg z;gvw!`znwiQ6D5y=WzCzi4OVl#}+`B z+Dmt-+(xO-TfosoVVBBw=IxR$)G102OW#7U2mn{kBaio|~dj)KOKH3*7lzlMhY z6LH4>rIOP)`tO2)h%<;X&|jQC%{Zrg>&_`^KuAjqsa^>__mY1Cq8y4S$C>vhO{6gz zkx2Lfi4H~S%tX28iga&4sFW7cHwye80Fmq<5*_t)@^1ZQ3Pe6gPfygA_H`C z##?avL$u!+ppy$tyFjSsFO`uNQrWbbheGpFXxJgnakJ1a@0+L)gQ)faiF4r7gE2UYMr zOOM`|!szbXByt^VFv>rImXgw{j8f8CON@UNL68X1D$ySl6nn^yfEIW0L6!2L-u%NC z;brz?yuEpBgSR&a^+!0w?U>&L!*1H#@qny8+0oQs&IU;7vu44mlR5J?84aw=)E@?g)9zt}!q@3e87C^ct zR+SB8I$|7^6>#*n*2mj&HoLx9ETVeaGKhOHTW#<?FcvFpbK`7bVkTfF+cWKfHtC^bu0@68k+esS7}-;@s%BLAD&}>bJikT|P07dG zz5sbH7g8R*pcLY5*R6nTG}bO6oamaVUTK@;7awZyO9(Z*V1IOs%P`DTL;9cG2g<)5 z!1SNAp}&qgk-ySbnmtnuxsQ_GI5t|RbaOAaaX!R)=Sbx~l9p1XCD^`&(!N?kX>Xx4 zgGGYvx|Q%7+a@e}!JWtlNeBW4Iz|F?6PBB)0!O(p|EO3DK_^Pe(sO&7@V#k zv!Lu07 zCiseQ7S!L%=^;G}TJk(BSC8R456LX3csI9}lM3NGvvHbF_)zKuhfLu+hC|%(BA7u8 z!q&)|&Hf#e?@-}8?A5M97ZHT-{BBFacUo2i#42DGbetEMHf^)5#qYLj=Ut`4cOJz< z3U`>8y@W`CUIBnxo!gM$T|um&H$=!pfwPL*jjSeMJzyLzT6~HhGKf zwTDz$)9`fn2=z8|x}giM(Tky7mMq48vDG#L+r<>P6cL}J&&tw607nFAX!Tat1qwx! 
z0wz9-n zvA8wcZ58I?;^L|Bow=ke@E9t5hbk!AOQ`Vbwx}l;B4lpCx&Lt7MDYrG=6C8!=@o2+ zE7%G#Ugc)T5Wl3tcXrXnDT$}Tcjls3ls=Tg3p~i&_^H%bFmV#BAc0rRW$&ga zh3`OHn%1lXUaAc6Ed!-e|4?dtsJ+i z#yld_SCSk0L69SaAv%VL(%@Z*ws*aF*V^DnbQiyEk!QQ^1dK#MDFSYIhHRnwq;1mt zE$YL{cHF`6KEnJW*?isy)`p#(Ln@$se=74*!fDakEx)7KDceV0yuW6CP~dA6VGQwE z*{|HIbo8u+2r948oxksuZ;FLM8-zDTwaCcKmLm9sV)a$6gp-QpUt1R6hT3HsV; z+iG*;&s^y=d@ z&^VemSv>*A&=Ya$D|BWY@%hx6vw&kT<21SQ_>O>c5pl8?vGr!04sLabQ%~WP+WL9Z z<3r_2r~Da!m9V2)LZBWhfLNnys7C0IO5)j<5Z=Lp{*mokm}A) zgtNL;5rfGuaQ*l?(=HGi?kxn3?aKT0dJ zwZKEB<*@PupPSIb18(B*6XG3NMfY&0H!{Ja#c+L_P8-STO_-o^F<29+?Pd#nD3IGU$;y z(=HHtsP#UnLG=RrHF?qk?+C`LORi^+McGZ$m{o0}(Ea{?^bNLw)O zy?akCt0xSOU$7~%6WRJi(iG}MH$@CIMNAC#$`;$V=B9{T<)k?YLQS-g_Pr{g9C%Bq zrZ9R7-_jKqe##nw794>@<&TxC>57XVz@_6lL}RQbj;fyS#^d$^oNtH|zM8GVL#9i| z?{3A2Q$pce1Zc{MfiHFU)pb=>=+HGhKzBWVjVK4taiCiRCg!x@{xZw)(a4QUGYX}6$d z&=ixF;fU%iSYU37QHf5P2Otzp3w8UX3XlVDKGhUaKEefb#WTO?C!z%>A(5^|oJtc( zpQdJQy$SJ0sl+qbf53?91voBQUT zRGA9Esi$yaeS~o|PW>0=WWbS5MjT^}!HlDPmRj=&a2V@ooNX|odIOFGaSBt}HD;VU zZt;kdNa08X#b%r(iB9>C0p~J}Q}RhwuLB&XC)5D%BP^zIe7?%21CDY!;)tJU*U{F| zJqJg<3};E65J%NL2S+V2?E<$b@ey95h2&prP)Ixjg?yec*3&}r25{6{a8>{foz$yy&t+g!U|SX73?o4#|oOlnZAa>ONk$~maM+@o(u1XjHIe| z-$z*P<1@{LBvL!?@?sezQi~d>cp`qNF1L#800GcQA4YUCmuTz;#|n}LQS-g@N1ZV>PUYe8cccqMz;G# znyA@r5)w_Nhz1Lq%|zo?I~}nDq0h9C&sUXhObmEFMfAZ(*h~}Ud}l@gku(B{j86<^ zBIWy3)#y9m`RT;c?E7wM<9-F4K+5wsvTMvZU)`b*Cz`?u6!;IIhv@LtPQ~)Q@P73H z#6W3ZRZi^9@Cx>)==pvFg#K{IG(s#flTq3AKl=sz&d2X7OAgX z9RLA$6A&aRot>4gwl*%)SbKHv6v`z(^YW4TEkLaWsO9sN@dweWliyAS5Bj53OQRm_ z26wH&pe3kHXEXm$n^Uf7Ev@og;Ku%>3p}OE{!mX8&@()5-PUMavZt^n3CK-c$!f8G ztOi0k2jKnqL_zXIIApr9pRI|kfE)X`3cRtOIzaK1zOfJZ&g}jW;@sk%k`68Ynp!*? z9;yw0;Nx)QQ>g+DwN5y3tj-|LS;S#xs6)&+AKkX4yFl386dI>;wIVnc9uH^=laDwh6i&ndE#hDUs0Car_!U(SDR7{u6BMM)K#eU#K$Ub6scJIV(~*j5r34M+VN06cS}tS}s)ZP~CCHODO<*xyaBBg#H5HpICH1aL z61)ymkT5_1|_T!J~mxhd-q&QP!n;mp_CXoGl1j z{#0ss3eOrl)%a>k?OblpthwAnu!aAp=48)JJt$p}dRRY~>#^Gf@{TuyC)P?=7Hv_t z*#YyX!(Q2j3VJCZ51d7=ODR}-hgu;>>VU7QGzKKU`gdBk;#l71aFUiywSfCSY1y11 zEt`6^V7_l_TDHNFgCH%NVT)RgQ^f8A^Ir+!9hS0wjAh4OX@5QTis9I;B%{Nq7)^z{ z)*!d%Twzm*+5+zAqf;@ff2F42H7;A!iqSl7Xcd#a1hl*N7l|OXI^07v7RtuJ)B%W&2KVn+02D`uMX&e<;qfP088T8BxAY zln(M%Z6I}8I;8-w=zAjKd)s#zgsXQ0xX(ovDSI1rDLIY0zqx09jUM^|Tv)ORyTzP% z3nuMVnM;*_;E7JpL`wew+^*gHEsy$EI*_+n9``M=6nR%FOR3W8MSG#NR{|()FPx;e zsuwvu7r|N^D&*JPxz21)?SJFno9jOiesBH~E0a8sTQ~kXu7>0+l{=plrEAWqXX?h2 zWfk|4MkB0OG*cxsi1L=wt2~CEJ&ldh6Rhp`hu9mbib)9C3?jzfCwD2~VK~}I-tXrxu z|8_jT)27Qs5IJ*s7%ZukES21**HQyF6pIVheMU=flSM7r0J~yd8+h59GjI&Ruwgw!#WXxg^#MZ6Cp z``>S~$;w<^C)pK1f^{Q6GAz{~l4%u3@|_LKG4)U{@_87faf8nw(=6hB5~aVnt%kQ> zu0)EWA>0-bygTA6v|AH?Hsw@8FMjv8+8*jRUxyU0PSo8Jx(G>)v9!yv26;*=1b}-Ow!^+ZGOE z7eXxKNV=iF2c#Jd$wmmgxQh*)79(C418bRgA((6=*m;NaF4M}jvrGDWsk!LKK!GF!x{(uf1nGjS`M6aPNl)q#^GtxxL@1zV83_4JH|D(=XKpZ*>7BZ${UQH z8;oCP{l>kF=VJ8Sjd#M^Y-fmOh+*89DP|8+GgP@~RZB z(dHPj+BEjnrR)+! 
z)*N6D3V12ad8rgM3{2% zv0u_QmJf=4?>f4!(?0Nfl>8 zrxjeW`xMaF(P!qDJQuG{)OVWmI+Ops&uX*ahnm2d{TI4*EBf#~pl{ppY!84@|2DlO z{lSpUuUoR5XaAe;_g5W<1ii~Y_FJd4jmFa(oAb_)|Gls3^P}+D#*PDahhBRBzq-j^ zO15%heqhL_>{PO;O$e68vw@LZ!LCnHJ+%zpL z6vHNO)hp}D@7{lhO3L1jSP%86L!%GbT=8Y6tKx@J&uwfj8T%BhAHI)$@5m2q z8UF+;!n^&Yq^g25X0iW}f(-}u$s>%YFkJKXH7_<)<6 z+j|ATrfPfdOKm@u?D}%@YmjKu#)dPCI(V9PoOy5B@!Rfq2x#0)A8nj0C zz4M2A+Yg!X;MCtqxOsoc`A6Hp*AiZLOkMw@ZK}>}to#(6j~4H@>H@3$2NNCnlJ+;H z+`RCYA0?q@1dMy|;5TCn^f0R#*o%gQZ?j}7cJhC)^iS1a0xrDXe-SJ_<9WLaxVg9G zk(lZS!hjTL)x8gzRX@d0t{Hfp>N>w~G+@<&9-q@w0H?u2(Q-mNe1yZ89fj(Quc_weE?HKp!-m?-Ycj3Jc6vfL^) zXWT4&9P+>ZL-~7%2OhgQtru>prb$c=kU#QY)ol8Yo5?1v!wGg*u*s5&<{9mN?47Z% zIxdTLJ1u>Wo+=ePpo3#C4Cw|y@5gqS4aK~lvAc=&BTC(X=q>-dm4CwpSpLvK&G`;dxH%kpOy<JZXr zVF_;H7T&~7{;Dm^${#$I8hO$-d4n4~Y5u$8k7oMxs^jS+yZ5L}Ryf((p@Add=Dh=d z>3-;4ufW^`4!g}N{Gl`MeK%yu>(hFwkaJaDP}bv?@s9Ez>apT=$Kh4rSnD5m-TA+l zA?P|KRgX_qz;5f`J=petu1LJ2UrKI1xlbK`{%p$l=BBX~ zlYYa^<1H6{^x?nHZLEu5_@1)cA|)C#FFFM`P0<%93orVN#|A7+!uCJ2s*GL!@8PTr zTB+%*r!Q3UdA`*^+?@W3HGlqo{QvR7Wv@Ffo4Nfc_{6d!9gg7ULTE1I|5`8ns~eQH zV$nwA=)Z8I009;i;bukjR^0i1RW_&U|Ko^dEvwe@r~~FJ0(H5m^p0MOC&sL9w5tAJ zr%oNGG<+4O*3b{2?Y)DLInk1K$TOu*;*>%b4M#W4*c3omxMWs@NP`jm~@IU9B zd1p1V-_P&={p@$2ecyS`dCob{bIxbCJ0Q$?9yv%H}~`qmgdfyhfU_INNln` z5l`2CM&=tbRuN66kZ7{@5>56QqRF{NG}mgjV$bUjh^O4eD7%I{A<^XTz!A!`HqFN- zJDzBAW)V$>DgLI&`7j5YYn8;y{2|zs=4BF3*OwDbW-5WDa|hVLAN&oKf!J5xp4x5N zk6SN2G=QzSh?kXmtgCOiV!jAF4&*+bOIYxiwwc1?9^gIThDYiL9Rzq%(Y~OYraBRl z@Bfw3roE+?WJHNoD@na|L-(S-T8_oWkG7v@ly}0Ue463!R2YrUln?kjXHZ~`t)H|2 z#(uDMziDv|Ezc5*8`vC6$;BqBJP(^aDb?6Ktm3r%Xbmg=;dC*6e_exdw)#t9QM4^C zuR)L({x`9?l+uCJ+_2b+zfCXy6`TCyt2r8c9zY8gm%u!SzsLKhVUt?*zN@=HxT1S% z&#j^ukOpF=Sx~2$qi;cs+k=tA_k5x#RQ?fIq0&D7nhIQ*3;v7nji&#i5P<(ud#4?Y z2-vQfY_&^kyr*UCBx^H8fNDaW_Z%OoIr-pDtgFpm-~keD&f*hI_9UXoOeLCYw~6M< z=K-2K?I+lU)|Ix_3{twf>irr7QPCN7O0QBw`5u6fN;U}dbf5ibR^}=|CC6?3%Q%R1 z2Pd%TbM>Yb-gNU;ZXCzzvpSHs=4^Mh*VA=xiF>;cRHHWYE74>32o<>5Whv8(sX_1M1kwmmi2X3dw69-(pTW8vi{tknlt=1FQG0^P*334 z`aAUz2Q()RRc`kY)+PyI?9|5CH`p}SYrEms5=tUN6%1o=VP4v9`cT??3-+w)y%m~^rB7Ev zBP@3zur#<_7fynPPpACorgABL! z!;kCk8~)d}Q6c1?Q$ghjXR`MZZh}vvVqP2$UhDcyy~I=jZc>!y0LfPgI z^#&t&8ZvD!T41GVJhP{#de_=7q z8_d;`Di=-L5Z(wIL*9-cP@{(Y2lMRrg0L<&dhoM*{RATgHRJ~Pt0v_JUxceEQm~&1 z&2_;!Y{v3au^AKEW(iF${{VK`#IL|*AXg=qL$*o=*OVDMii zq~ep{-Vdw>q>MuHANqeoaSN(CG`WE7ZC^w7V$Xp-wq}nHD8B5q*)1n6Q?%{Z8MAx} zTly%ycxN^~zDd{|SvKV&6TY!=ib zbK?Z>vDqSciA^W}Z=w;j62EPg8)-5wgsHCkBorHTv7LqRw(P4=(lrnv;?C;{%t+X} zTrZbyy`G&hs|!VaI*DwnZ)S@9Dm&Jom}TcfEE z&u}TtZqBuKfn>pZh_ydm(MWdik*qvHGQ;qR&qSwJPNDaVKu=+l&Ac~5ISb`ky@r80 z)p@*BcqfC5D-}sZQ~H`{%3TiQ%c@wSsU191D@_BAF**nFrGs8z&-$lC&xS!q@Tsv6 zpYuaZ2`1v9|3-B6s>)Y^Fg894aFulRYY^9q#&f@ETFbii{I`ZvUlY;ROU2VrrBtGt z`@}Wj0JmRK_QyLZRB!apy{GjZD>GfYX>p3q zColcFvn|Q^4=i(O@Na)R>>(rhMc`cvm9U0?ix)~7m z*<-dMivO**aw^^U**21yrH3qi3Hg%QG~i^#MM8Fa^(M_x?KE{R!t2!1E?a#mbezvm z)DQ$%^|n4*Ar~ez+=c$N@@hr{^WmM zs+z13Bh3X}>i@Rje+1QF6e5Ii5c~c|pkB1rMx-`$%1oC+l{cvBG}8)*L82R%RNYVo z#HtJzgp80UEj3t~-eta(fq6&N4oxpdkq=%@>53tm+gD9yLVvbqCNiVkV_$RUnb9&e zm2+a4TPK9Yt2a*8_R90Yo#F4cRBFgrVS}bg zRqO%(yuMFP(|?+bx{bng$ZU4WNX}`59G0R}uH|w&qsy(x+a_>340b%A!>6*=L->UH z^hYhjP2uPut)-9Isr3(1K5iSW+y=_rK9Wlu+j=@>BR?L~r)mQOUoogv++ zIcU(2ws!R>-U7byTfb2?Gli)gaWxX9m2bX2WeuB1D3^WusPMihXL%GS#^2IWZF4Q#wAMK2 zWA=>Wd7OS^ifY27xW6YCatwRCbh_y|(e#RHy{-5|bL)O}Eun6l(B@g!dBQL|2tI2| znqrqF9W84|{12_$SJnA=5N6#c*jP`WMXv01&0h?ME+gYX)ggpObX(hStr(lox<%n~ z4GVFg(PMUF3*Si7M7DphZF|nqu(hd88iK{-?abrfGz0h3FykdoTVK#MHZw(KM^S3q z_B_Hks!vn$0m-q_R^z`puqg1KXoZI`gs}Y-F_}()_vZqt|3e4p2^~eRh86waR^X-W zwFd#N!2b_xKwBe()I^KHGe>L-lzCFzH9V? 
z+8qkZD0?h^a41y^4D4E6vig7$Vf2U-TseHA#^c%-5_ zu)w_rotRdBhas`~D<(T=;(|o@qA41vP^1;wehduIkheGz!v^||W)ESs{`0Qtqxfj+ ze9YqPc8Kc|s>e2#X1zTz5g2HH*pU@af_)QSf_y<&>-fQTT-0d;4Vy{Ce2Vt*T1g^M z`p;3v&uQ`>18$Z{sTStm_Ai$K4V&SlEw!U1Q0uS*H=P-Zb6UGV4V*1~5pD;4`9Yw{ zO+#39=mDWGKZN1nZZ>RB_U+|5yzb*Mzq; zRzoe|*?6p414$)fU)R=oPs3}EhgmEexAhV1=|8YQtO=hNRp+!5Ectr=Mf{eRNXj2P66k%YziA-xRppI!d-ip@$5I{_z*?ZVTBflbKe zF5PJ^Po}9(NFYP@(bhHK_#O@XghShBY@)^#Kl3Q4->BtLxfH*5Ml?DIM6Qp30)L`9@?#~`R~=ugkz#>RvuJU59(wAe z57PsT>XXLOKf0=%&%^RQ)2r19(;w2Yml?CX9LU(IB%6QEAZBK1=C%P0kHrd_y(gR* zmN!r#HO$YzA97yd5B0A%bK1i%mN(%Kf+GAO@6~qrp>6OkP1!&)xFRCz6Py@#v2FVk znlclPy2~{-E#+1)yQ=fHmKTgiMAPdx-qaibHi){&sB?58CMs6u5M2ewf7hIS+_Z*i zT3Fn$-WuS%{ksujq*%mnXgdZWkZOPHnqZDrd~zSz^;!Yo*;k)E#fss{F2`6P z%AmH-)3v-SBBRj%14m=n@DBRrl+Q4s8=|C|j{Sa!c2etUTRH6+`Q@LWPnJ+@sNpc} zU(!*Oujm~Ged4RedujPDV}$E@M9RzI0?>K?@_8r zbZOR1i)QS9v1%+;9XYG=xrD5cBiQFkpSW*r~t9m zW;xr!m<`^Rp6yOI%2y+<(AyxD+H09!Jwb~y_}D0NPY_MMC7f#TOIz&>x&~|)QHAU@ zH|Llt*qqzf&qBNS!#Q78_gU>3i>n4Pv8i<-o=RQ%;!{H`(KLD=0XiR+2Z6*kq*jdp z-Gm+Lhsy=+KK$lG+F3fV!yrbV_^Q|a!D7zv4*_;I@$Uhj_1Vt`F+5Vsy^n&$JZ2FJ zKe!8S;SagP2g45ycPs$ng~fXOVQY$}vxkX{b&5}1n=Pa%A|ofld1_1bUjw9OO}Em@ znHLh<8vl1QSTh@h`I;)-<5-8!#o9rg$b4IyhP6RIeqU|pd9AY}<+-ykRE}6(1libR z@$;~0<^M@Eg5%gYhZbYwvh}TzwsG|lG(7Eqf%y$d{Nyy(Ara)!$3@`$d_n^}GW2pO~WH=1&{|M95T=Ry5OEgz-EqoDDewj&EZ z&}U(_Rt|M>htIL*r)faZi}Fjzu^`JRW@f{6`A~+(#m3-f#OtU-hV!GwDR!i71d?(W&OG-ZB3Bk z`pkeqI{(ASwB=hZ_fr3+ab~%sbkKRQa5F5<0YACI8nDM*TmguVCTm^uB7EMHit87F z@BI4SzXF1&l=FaKY}DJXrvP&?+3Iq0V9kPRI*}bHeY)Td_1VykOM>CEO&;D zWnAb%lterK7&dN#-Nw)@|1)-L$a*9rGw1BKYvv8c~tx#*^;;!rSqedUXao+RxZSK<)7oV z^1XAwU}3OVn-T8~VR#*w|0w=nIHemxpCWVD1m)O#j!L0wo;ZPp1<;@;AfvX7Nh25@ zm8HSgz*vWtzXsubV#U{k-e&!{wBB_=BJbf5^+nAKL2A%QP$(MxXBds!f`1 zC@j}?iN;l3RX};q&nK#Wqjqco_Z6@6A0zdf1X4JOFOTGB1wVs9tN#-hKF z4|j9%U48aJG%xiz`9zN_K0Y<%{7pPvyGQ8HSL;<8d)Frrc~EvA6GBy7Zggb4S_0Wq zKwQ@Tv2=1swFafc!Sw95mbLxT|0{JiY+?yxm+CrGb5UCN*lZDg(cCdHHo4TWp&i8Y zp#1eu*wohZHEnODM^JOtVWhsfjXy{$9n)yr_ORRm&edIxr;Nbnven6GjqPMv2mM9; z*=aaHb8pCNXFO<|_uJp(4f*NiQNn2In*Y-ko-#mla9BSZ9+mUqyw+{8MJP7>DxnVE z<60VyCyX^2{D1mxn_25DD2Ai#+Hk=ge991f><3LN|5I->wdEJWorT?KB66c*G>yXk zC@>18cP_&u2pW4|tWl-kHZ1OFiPusV$YQO(6yTugSD(g%tn3AlFy^t6OX z>Z0ZW4`w6wDRTJ)NS2srvKPZuK7QM$P2Icr@-W{%RaI}aTns|}!IsYOzY1Ei;c}Io zmZMe+GF+dZl7YV;x4I!BympPe672?XzNfKNEAsmUkzcMM8m{Kt69j>|t7bYSsyuqQ zlm~Yq`E6eh4kK=V{d=_^^rmiuAf?Y^GI?oQtTI;f7fdbA5a}E-8=2(Y>Vqvise2$wJo)##LMQmwEbuAniKfj_EQa%&`V~iI>Kx90=e&v@ zw4HO;3VT5c9zn%D9;&9cBy*;JdkmR+UH_qU<2_A~8puLz7JNo(KJyP?vyoqm8tNka zjjH|wkek&$H|O*;cl@fle;f1`yJYd7B2OCy4cNFuJ=C~3h_{!<3tEV-f`opLRN^Sh z_@Ia-=BhmS!ViY%2LBt5?=FZ#bk_vMjG0Q;2^Iv*{5_cYW6gIi;sR(}akk4|Mx4C0 z$1NAYY*uAC6U!^x45zsEv!zz98@AMq6J-CfTRZ2cxPva{q*{e(I^{8Nh}Uq30*u-eQ7%;-{K9@6dy2J z6ZtZ@ty4y2 zVU^b23dzS84YXWXl2Hw>?_<;EeG}JdX=^BDQq08RN(UJ*499h~=lD|1bp?NvGP^e4 ztgG>B$<(?}1F4O(ewRKk&$mHS2!GuPuq`8j)$hW1xxCmbE4 zxtKxc*EO_V9Q}W}@B7-me%AUkCse?FRV`Gl<4&+X*B)AFouLtbHn#CWz~9kb)f_KY z98+X3)HK{cafy3qh9!fL0F0IccN==qG&KUr2s&n?7u%j8myA>Gu+gR2Gp%^HG!{D` zg4-V_6`wy_bWLjGad%s*nHGz<+wCiDzXrQ~S2e`bp_7V?1+hO27?856h`hc9+1IRm z$vC=_^ytz&$SEsRmZa(X3$y(X8nyA#_Q5(TxtepxJ|;VGWTUTw~68(DBK#o95*v5NziM+xj$32dD3#wXWPyJA9^B!CxrWA39pBu^|COJxXszbv64M4w%a^};XkK_yp3B0<(`wdr|0lMqx@`V~Zc{)ve=qgm zKUV!0-0Ek}P~Ni*D$kx-#ilA3MHf+&xC-ZL9v0@qJRj|#Xv{vYtrDlGRMTQiW&VYm z(dp9GwN_!cddVYU<>pz$FW){)(_X)Yo7(R*Q&e@fo%~YzL(=PpgyPa@(&Tu&EO*_k zDSZ5cpQ7n)PZO%ia2PI2uh&|9sJydFzaG}|YSw?Wy*9(bQI`x2JcR#f`txX*F4h}p zUATll6*}tZ)pax8H1qgP5#jNWW)RQ9Chzzki@ZqKU@*KIw^ohgfue^UMU%8*V=#v6 zXIRzAG`46$%Zq*gVNr>but|siUs;2$mzw(vsT$)IDm>QR z)(8~vR+LcKT_e8^&gOZp5`6Ylr=Z`uGUS>$RN+xSY0W 
zYp$m4-MAZ_;Uic z;~iPOTpW}UejFS>DS!14#HxPH9=;R)#Grc1$A(xHisJ#BG)-8B0xnOseX!BoRbVx{ zKD4^i<A1xa_VTp#+<&r6 znnwZ#`S)!*%10ge$lBx|yLiCBkUx13Zx!V!k}E9;=mU181&MCy-lczM+&2DC^WGXAfeEO zQveCeYKMLdgYwbGffYkxPsV{R`oMciuO0f6xh9^J-8NTvR2?zXsFt&D%)s7Ge?$+w z!^N>i`8e3eSW|)${n&c4>)nA)-?l&if9V0>($eifTg8`jplyK_3|;*0Rd_uBDF-v1te1eD)G-l^H!k@qX(k;pUKk$0Q%$a`@;(o-`I z-ir=xSMpx*Pr91-xrl^v7pZz00bGdsb~ip2yWjwqqH~hnZJ=RoAlHmiK)Hh*7GejA z5z@@8LA2k-0$SEqq`9~eXco0u0ygOl?EJ?@?A+cBI+s-jiesr_He=gb){fuV@y(pQ zPO}cKfA%q5+BfbcySE)xhkCdY5vCzRUN|Cb|9}YdEPFdp*>wH2OheD?TRMAI4KUbu za{27E(?fLvDeU2+yA|4u$g%9r5HafBqtK2Z0kYqb0G|y=z|C&P`)=9Cap{yu%ihdF zmYyuPo3ROp6)j(kqnvI?m(oRCR%syBjN*A2sV#aONTt%nySf?Ettr#O0SRX_zUMp& zl+B7$x|z24>>Y;h?QbIzDI(!@M*L2}k2Nat|ivHAvfF3Lt8TYE!i(F^Q+ z#~*nN#m+^1?3_V#&cx1XS=jkW@7`hk7EHEba9Gl-2ogqs{E7HO79f5HW8)+WGbY+?w_f1#K=GvMv-QS5 z6f>usa54e1ucZB+E~3Qn_TZp7R}jN!#87>Uz~JYX*xwtiTuyqrj777i%=Yu+QFP2q z>~Uyb%r^Z=6qD%hM#gaz#>e?EoUFpgn9iXW&-3HZsgx{4GGiXfu*k@`H;dct6d;M1 zkHo(l3K*n|a9q)EIIe#$7?<%Ig0#;?RH;Tdt#hz5A3IZ~W9QxHv9ndL-aaA*Lo$oC zeII+5f-`dub{5UT&cm>?Y$A4!NyETw(-BGL5FF73kz6>7NEC>KafBdmHg;Z2kT;beZ-_%N#hOL?(MtbTFe0NT z+LP+j04_R=;OYAi+?xRRFjCnQ(Wah+JrRvHBJ9ZtOB_^sim)fl$!Je>M&ej9F4_~W zF-X{&@{=#nKLYI~Al1!Px$k^qLZC>-}mcO17(4&%P}K;A15B>giFcyz?h zqF31Y=oaL0FLo|&#Li9Qk;gpj9M*=N-wwde71+5mGF_w)ahP!hZiy2YB;VO2^*NIE z$My7Sr4s$!`&ip&!D=+g#{ShM*grWEkxxM6iS>y5^*BH-N~A~3!7VnLqQfmDKsF-Cmu`TcQib@Rj`d5|ld*ah z1`dqdWMJ&DdY{dkwcQGRjLrli+1v6Y`g$zAi(#C|H*(PhN(7!5e|cuNO+9UdcWj2}lBXH3*G%V6i?P4ave`PWjrEb+`;z| zZhq+$L&CCgjgWX&4Py9i0poJE;<&OHlZtU)A;3G#lE>lppG%j`Yrhwq!tA@WWq zq$CxQ#3>NT8bl(ShDc&op=<|WXDFQsotI%}&Munq1s+AO;9ng~Gn|C|kFLZ)znb*! z@5QBvrN^oI`|NVh0;Ds>DX-S>%gXn>9kxE7?g2H zPB?YLCFs~J?+b*n`V$Kvi;0d+*lk3zo(bmc0)!fcP)YvedZzy|ygI`q_rlJF z%rjGt+AwIWBLizW19WVz>91nbsT?4Fpy^fnJw3UP!WZ<~wl*bc{W4!#a2RcAlbFkl z6SsvLcVjVPql>fk4H$GUZe4k_F6>aOCcYe&I;Uq_RyFLu_Vh^qFfjW)3$G7%1G8^b zP};jcc+2`VtBv>IZo;{$znx8?&Psk3ZA$_0KSJc--T5+>WmVHDfvla5{$X@4O3#?& z;chyhaW?mR_oo8a)L`Rofpk%H`xX7k6xcM|%Me^_A4!07dX7_pTl$CTiNMX99-iNXE7rYJh>GZhooB_XxsfqV)xyd9JF}mYq zl5u&u9ywz7VhR5HzPP`L+E+%kNaXf^Pkn6gT1I=#>AzHO^w(@V=5DjNsLcg8f&B09 zpKV^Mx8eKuY#0(4eObli8ogZlE#53+$h0Kr!6?7F6MB6A{{1HONORn+sA9^CzMScd zH+%EVz8QLmF4><43c0JY22KF-Sys)R%-xh?uD*_rj^}}?pS$;za<~UPt$XSYf!CFb zGAPLf9`;kl#^}1|P{-y;EOy^=%YovPNH9}LGyJ=VY61VgPyD-BK0w)jx5Pr~HRgxT z*uRkI^Fx284Es-T5VE?56dU{_qpAMkM2x0{`+7|B<=md5d)k*%0`qWc*iXNHusZu~ zJmW5tN{?pIjlz9tg^u3o9jESZU8)Ig{J5{rz*`@TyUR6HZX(!arXy zsh?vV-^81Bf_zXGPgf8B2Fn33eS@AMAATJS!p=)edv^_lz31ZPIZO3Ue%lk_?9F&S z_-#4N>(;8xLg+c({kgk~H~aFy_rXi`(#C#EAF-R@hOmX?HVV~jU zuB|?QA_peMrKg9&G9iml{`yz9t9pg1ij=>;!Uy<#D7^r_Ihx_+_@{3M*Vl9GMEws8@`I#Vy}Cm+ zOX8!!Yt^5#<$vlX8GGLxKFzHBye(h9dm%oO>oRa?I*X7?V2zzG|$iP>;FM*{Kyi4^4g!wW3!|1$;qz zG<=_Ucd`0IxqW=#`@1$~@a@#@J;|O;f3e=eXMvr9RG-Y9Uzh=3cYO?OkJE8d)e@C1 zv2(0o9GwW?*Y%l0eq5Ka@YXBnS)!_~%Ux+!HsQXxFED9ss34i#e7+FMU7G`wz6;b$ zt};F9bKiUs$>iBd;kJ%b!%TrmJ3o@imQUuud>gP#ua(sp{ z#fHyRV{)S&!digbJiK0|a|ReV+lxC5W>kY>dCVwx>vO))OUvluy-hl?Lj#GI= zHCvh|yjOk%IfJTytFMTChEta*E z51i*s!p}R)|1HS-*11zW^S*gHDMoIjnY0P;h2V#;SFv8wHy6ikVcx@HrFh zH&H*x%vo>L)Z!lGM0zCk0S#FKd`T?wfG=f<-aXf#TrhkCocJQ~xuuukF}}RQ&RAOn*B08FXH8a2jx&<_#Q|i1D+>N?jf3!U}`MsQ!%bK<|NT z@fnDlWe{yi4qmii4Gex62aiZp56<0bYPuM{K~56G-Dx(1c!s&I7AFEyWli0UgDl`f z$9#~pI;ijf@eH!w%xWSH7xj4MoiMLN3s`8?23yq&?LISL z(PmhcMkFVrpY~w}z(>jpVDNTvlX~zIcVKXFC=BjpALmI9CRz&BS+DAy`uBOFK|g$k z!B4Q_3RJ6He16;AFfaB17@T8tMm=~-8H08TR;7CI@!oONRrk%4aqz@wUwE(N{I}o% zzoE3bl?|U#pL&;Sh*H=vIgu2T{AojIJI`Ce(%4N?ZLF^w;?7yZPkyd*v z4K(?JYH3v--v-J$3}<^N{Wo9v8Z7m!V=+pl;-WRpFzqFiVcIjIhrzTXT`U6!XQ=F} 
zVzP)j)$FR?W%#&1<031k;NJ?*)oVVZc|cg1r-ORQtaqvQcACmECI!FVyDk==fp)5~ zrvpj*?v$-SJIfqszsfNtmx%b7b)MsW^O6FUVGo9AC+8q+` zU~f^4P<+N8syV;MVe6u$SfY6o{i;bGqdU3ma zRHsH9Ez#A+rBdnPkD3(vG(|XN1x#!hR2G&vz`&h`~W>qv-$pG?e_1~YB>uE#OEmD{0Y2+ttO4_e(ogAnW= zW0}6*jcYxWy+4<0{S&=4n)VxOI#F!d_YcHTQDYgsum@CAeGv%dQ7hytw=d9%fHKRx zHT{dK0fE~wrol@n;`0(|trm~l0Nre>MXY$x5WDn~5Iuzj;^MzG-<#8Ssi-tMWB#+pY~uVl5o0VL^&B+X#6>I?0h zZVN-8h7n-EN-l|#-nLNV2#I;eP2c`+ft>2L7J#irSlRY1Y69L1ih3=FWki<7dB3Ol zAU*?QI#w}T$Z|idn_K$|=$?Z>Agy96vb2sQ-z#!C)OJx*pFefWHhjjyE(T(&owF=S z1im--0fHHbptA-L=wkI{mhNoZxCJ_i>S8gOC!svu4uEblsImG(WC`VU{Sd^{b0CyT z-U~Ij;Ky!xo}T7A$bwIw*BpjAddp!QhzI^;Lsl;G8NZZsXo1dt^>e<&hj@m1Kl>Oz zo$O!n7zz#rV;P)*^8iX&)Tn5}5g;-INA$I^Be~QqyLIsOykNt*L50;=!KRqfy6(38 z-~ye4SUuv?J3Nrxsxz5*&C2hI<-Bat)Y(iKp9=eQvhrb-Ii zp)MlSX0ir~3IRVT5ssu{>2LsoR<%V4NF%o~Yc45G4*11@V zFSHmL3^4f1KvJBPTf9&Z=M4}iEuFrLtWPk&l|WT(&Y#`YsSz z=*|VPS2*#~7~u}1Ry$H8#cXCORQC%9Cd3YZsF`TL8P&Iad6Ep!DGOwOdlV2Cb=Z!* z3z%RTpHK}$F?r7B06b?s3c;aR4bPJEI2;hda1!2!&DxiE*v<#x-HfU5URiajj=d7# zs}Y_%{HYq=N6uU66PzSc!_Q{3j_x~b$3u8u<3Kh1M|L8>i#7oKiQyG$_%Jyy(Kk4$ zS`EL5%{qJFuw4YguQZmc;oI3c0MFV8@YE46)$qx3UJgtmX9P)qFq@SQlbDF`B4cZK zucY78u^QlI2p>4&jT&Ai=T-B9lcZ|+FgELY@?pC)gikgWtKp|P(vQO;+5+(M5g*m? zZ{$3>?vcZGvR$Nh3^1WPrX60OlZUF8%lH_iGOxT|y3L@v#e+e0|M48DR$WIE+GyF~ z1=0EBs%G30xk-*5wktytij27R)c8lkyd+e}XlbVFlCUloK zK*CKV;RfRiNrJSuZaa`rha`N{aZ;1uVnTDHq%4RoCQ0Dk62(c<58Kh>frNHrsuM{< zxMLcSP#py%aGktu$^LP-cxk%7L#6w8QVMP+^s)9S3v`|!36B}xOI1>+?xXt^NN7hA ze!BH931dxY<2gWrf<&2qOH?MwIBe&JB-oqEos4>x3>xa2l4hM~#NiHQ>4VJot1PS} z+{9;Pjy3xPD87&BUn<3y4!c|36G|hihQ7Ye#21Q$I)eh+IP3kD^SP$1hGtDaYZ0_# z8x#v08)zC>C}n{0y#8g?HX|Lf@>}v<**2(>#xw$(_%tk@w27QvlDxi6uyCEpg*Vr+ z*uo|>*l@{GQ00YQVIT_J?I^liOoD;9G$c+GU#8yHmN+sQH(;TvMK<3g?lizJLwKGS zWc`%zk|>iffG#e#Nj*NqF6bCiKBzN>(F$u$PpC*WGIp3GnR*pB`VX zhF|B%JZ*NwPDa8@7MaBDhzT|fL--J{91>o#-=q-WS*HNL$h1-ozs1of)p&id;fAHE zJxA6@_b>VE6-hi-6tSybC20qF_o{b;9drT_(q6n#i;&k}B3_!hK&OhFa^_vpn}sWq zWC$Pa8uNmLkGm&|hbeHyDX5|(kW)ahlvtbaRwS{;kyB9Y$SP!Q2sSK07NWc;3FH){ zn7jlQl92^Q!ry8Zo;xxhiUtH57LhDSEp>HF-$5bAaX^CSxR}3563Xt$_5cavkc5BO zK!Oyv*)a!A{(+L^Fg{VYa?U-DRj}bh9Q+S2c7jp3&(N~diI)4wiM(Rc2~<$d02Q2s z7N7#5INb!H$#%XfQWGWR;2V`JMW@osC*fb;h&5byc0U1q+7Qs2r;_n%)lD>sb=(QX zGD!|719hW!pI#8%3VIj!iBU%Coou*dF2K)2_=;LJykwtA;^=-*;+}+;4wUG{gTvE~ zdfUc$LBjJ!=*|at9>Tx$@=%X|(Bz0Sz`rKpId;0+0Y3UOj^ECxBjG6{B?|z4HNw9+ zh44gCzc-^Ok_9>)N-<0JjF3RZ!1M&rcxm?&9w2ac^WeT8C6iR$DY`dJ)~(91E<+K@ z5;|2Plq&v|K`p=*)mUTlO&&tIz6mgeIul|#-R8bjDf<1PN23?Ser@j#ijLDKlQ~Jk zfT%l2RQst3YL|~~QXxtRHtbwlC?o0tQcNVSZ%!@H`Hn*Q#`v-{i8pCPw{BU{6P}oV zqYrDhy;z-aa_5BWMn0?7Q0yFhje?RW6TxwBGWlj}ycZn$agy(EJeTEIGwht8c(Iw8 z;SArzcy{_q_D@uO0}vgWo%iAc(S^+EbkSoy>tsILu1oblCRYD5&792+;pi1olD!tP zzzf~G8B71U>rIuXQ_XyIeGd&=HS3ZdQTH!0E(t;>@4yTC#A5;d5JpmAUcb+Vppb1d z(EPygA+zqTn?(2Xq%3#a7$jx+3w%$78QHrrd$If`QLgF@9Ob*#%~pqI11-b$2)07_q*LFPUE3?Z*zquXn{uMKSnl=utk1wrHB zXdA9|L*cgc9`mi5dOR|h7IAJ{AEi%7hE{x(9t`~r*4ZTQI4WiIDpXWT8FqS`xssnb zGuKh?n!4JDA&wy@ZMe_&L(DjI6x8$^5+p5hkne=TeXskq8}v;p;^^P)@r|2SSt%tnfhk%t-3wcvURhvQji35-LleUYXaf93L_i=Ny+;&&+k2+GIoK z1zP$tz86ltsy7nGGaGJ`<6}*oE1N?mCOBv10j92cCy+F@xveVfE~1X0K^}LvGH@36 zuMUDzc}2h_R~E;x?Gsi;=4Fo1&NxGgb{^y=?$hJdxe4ojCd9>@0+)b1414yZn7J)XmFA zuE#yGg)(;u_s!ek#}${U2|nXV;S`^v3P3p5vD*ZlL1glT8c(~!o_|#@aFi;0%{e9K z7XRhIO~$E8pz~%wOnw{Y)r2Im1_c92%Jf6x>L!24rbNJ$-AYcJLQYxwDY>6>kn>dE zLFAN;?U#6P^gJk@hU_l2NxbJEKYZW(=shU?cEIpgTJY?rM*x4=l0bGPZ{G4=B>s%m zq9w2-A_2edk<)7YoU`IUsFk-`P3>9vT9D;xM84AI^ps7%oQ1s0=50{Z&;R1jWNOy; zH|iGP{U79gO~Jpsr_f?u@0(}bBQx45w7Zk>83cP_hYVOyCXTN6D=7sM13(1TgF-UU>bBytK3W1i5Q$qJ;}Uf%Q1cC` 
zSvP91dM1w1CM~Ia>Of5;2dG}fd(%tsz`Pg{7VTgUzN)t#7&Q#nkxwBi>}?!1HEb~C z$X?Jk*p%FJpdN!cvkC84kZR(3G&$7ls>!C4gSl6dVC_UA+@_%FdtqL`N_#o}nLc=n zdBsl45dD{DCif4~UvtJ;Z!`7oL-Qt_m-mLJVJLU!3ieXRSRgX$pTTl@FZi{tB+M&* z@1efaTu*>}RT+AUNsDvZm73VIB(#=A^&MbNwT@B zJ-xYkS#r`S~b{`4*c{UKHtlvy8 zt==iD1WF(4MI;FyT@rx=(NiEHa8Z$( zgs^8kk*-;ao&6!zD$+e<8t2BmvLJdLlCajSh$KN2D)9#r)*%V^U;RtMVbj<#Y*31D zl7!@EqI5}*6gv)*!0uC{Cc)FS8c2|p0twcOOVlLDp7Dxxd!^XbAcH8sN?8SXbK#(Bu^ZMNQmNW?h3O1%pBJlC zZEPi_7vwNe)t*lZJM^2gs+@D+$BN;}%Zn`|%%CGxPcF^PoR_99@Gg07TFuX~PRs_E zFs{!g$UiCbAZ0FPSGae8OAH5_dl7w%sk2rtk?B4c^#4@T?Q-!Ez{davLLJCy#^5+K@seCKTJ+R^?FH-$r~sezno-Y zi7TsE3K5$xk}%ENO-;f#(=Z^R7)kKydr*~ylrF7eK}0q$SQW0dnzld&U^xtKvs~}s zc~RBr=;=mpT}dkJYbftQkXMk3q*Xu#5hc{IW=aB+WI%#4T>Hlf>Yy(Qh6v5id*$JO zBQ)MN-K(H_#UnuVme7Jp?brR=Gnp2gkOC}4X@j3dL#+p1th#k zsCyUSqf>DFBW58aJZHP)I>3KK_^?4MfPNxE3vt5;EmbKa$-`KT(25U((UA4EhGenM zkseS^ohDFqr|7b~wZcTa$%ToiK75Rvh}s&A(6*2hF`yenXtEVB5z#ARj=5*nsvAE&uTcC5Lud0YU<5{(7lpu){0yuR- zVjy=Fk%HygS?@Dq(r+0R%os+`O|yA_N#|;Lzrt=)-$S8xR*HD1%T*0s;HZ3aNj@fghcni673)RA6U$h%V9g$W9aE_X6LdTy?f~JoVfI0i{crSU<(~< zRmkS%4QE86(9k|mg`Mh1d7T_ciC&=d`?bmNIz4n9z7|7qc6n-V{f%Mb=#XnQRqsbg zp)&jHSn47SznJd%#t?qpGM&*heE_4ebWbDtc}h?T<`cq~#^JiStV zh8$@DN3k15E5}I#7<*VVIfkt~pO$Q>+?2EwW2=|HSuyiwvxy3ler1BoH+ z4+riGlCs^uaE|iQp}I71KDi=Zjix5pQfks#hXjJF0oisftoeaa#bfS42+TRvEZ#&4VP#RVA zs5#yfXi$i|j)0Q(7y~ieoBQT*4>|hhw(Gp{IiV5&uSp>ufC9HdfwRY){#W3?@JyT` z)Hv{v6!>@pgFuPgf%_K2wm0tI3^#pyE|)Zz3XOXgF!Wi^Hjng&^73H9nRunOg4JR1(o z-!Fz}_6m>}E-|JJXR6i)g|^YMC&G0ss4GH9QnCzeNMKGf1|fgt;6vo%#Vp8S|cgb0__z<$~h$z{Hq&M7IM(E3@ZSle$};T*Fr8MEzE#>Xx5K9%}Ff zF?JFdKPLmmFicde7{bfdmu&XL9Ut2mJGuQfZ`q6sW-N#Nfg{jgMobk#UMm+>r1*C1?9QCxD!wm zt5HZT8QpvU#Ro*uV4F^&5N|dZf+&U|3Q0$v8int3IV*k{g!>anD$ZCkdQ1TnO^D*7 zZ61k2zQte!qF^Hmug)Sh3h8rsSSp}kt5N)F$zc8jD4G$)-?l|03Tl+WXhh*m$l_~> z8b#A{x$GjK;Hgnuwq#gN1r)7_qQ$m^M8Vx^;D#v15VH7MtwzBqmosQ1;g(6=BQo#7 zH;7@%0Bhm|nC!I^biR>>IxeDkKWK5&e$?VOkG`q3_?o3jFnd2N@bBaTuiL;mN}Udb zbRr>NY~PcF$afe7AtB2MaeQxA6H;9+Pp7%!br+Hlaj1cwMKuu8goHev5DnOs>8A9E zup}TPF$oC~`O?+*8EY*~!qR~dF+EHrOvRof z;HoNdU&-~gco^K$QV|lqbU+A&)mG z9C!q#xI3VT>1H^VMB)368iC@7KorGZ9%>X#5xz&{lK@2CofJn?1n&rS9`Opr(cRF4 z#3Fsi-Gt)Ugjguvfod!$%0UKrE3UnQU!@bFkR(`~Q zxMN7|NAhxE6Guvg8_QBek-RrdjpEvXVnk6!pjdSpP;kX?S1dejfT++x4)DZk-2RRW zmzXag@aer$?_SlxtX?6reFQZ1GW$r~igMkG?8vdKQNKaO!qaS&wdJodjj&oJ3{KGwZKb)~Ep{HGL;u)muwSqq?4IeN5~$l`$FA z>~C+K{v?X@&0&Zlkw8)O_+J#K2C!cB1#dN+M8ST?4LSx6D+f_nPcBfSDBLVV6x9TZ zlD(E{D$)ixy+sroNfe@Y;!}ts3{lLQ9BoNbQMQ?J3a!p5WTbl506>AM>Ee5vDKMvz zs_9S7GYEuj5`o2&JqJLJxv4MQK->%mQ7RKiy=mau4=j&+Kip%jKFw0?{gyWj;1Xi> zWEI+)&0&Bx3eDD*fh~^@uasT6W$MY8=9;*s=@gg!XVc_U14+xAxTOh6W1J-{^K5${ zP3iL9u*?xh0h*Ies$&bV_ZIJ0u&1(q0s^T&dD=fP&wFxPE}vI?T%1&VoQf^{pF8*Q zarw_MaZjs{QOG=GSxE1lI(3j6Dj!?VLA(IGM1 zPZEad57!bUaWwAi-GlUGE2A0rT@qKaC|6Cz4AqN8ry#JCKHyD~C~qdFPhO_Wz;U>E zVX@p~krO@EG>+9zbluOX_@tkc$QV!ba{CsoWWZHJqdCtg#Qnt1VBRS%hM@wV+0JxG z^UU!)BVFm_t9eCHbWEN=gA>J5I(om+c}bLma0Xd(fIX*QnwxZvx$c}_t0(I^aHmOz z*-X=QuojK_!zsu7p_6Qs4pWdLSz?MQzk_p!*aw)#z=`x3IJfk%(>)k#VHAcx1s?Pp zVn1-Y?ud!gbuW1oD?}>{{Z>I1WZsV1&{cPV-839-=|PU?!5O}BMlYOlhV?ud!yCCW z-Kfqf(b!FpLtR_Sg5<)&dk3RWDmakBnBOn0I%eJ|zg0=q2>W;Mi91Xdt&&F98qTs= zv?_^`1TSA6Eb>Z#p#J5-Xgn19R*EEJOlR@IrDYN&vya#A>@#ujinb(~T_x4~o6S#LEGFdp@ekUZcVGiP! 
zm~T11yTTMM&!rJivLSF%uCo?0)aS%TGj8aI?JH0TK*sy7n_>!ipQFT@I9fD>JnBgN zDr3HMPPVf@M<%JCb?@f61n%dp!y(q17o9)aghRX`hPXc^Rh=_3>lx_+^BCoVYx@_^ zbL2d-+9YA7{lZ|0NUELJN*!U(PZ) zYfvH`Zy}sH>?~*4PyoOePMWvoOIn!-A&+DN1R#|pSt5c$GHE;-kU&0 zadmB@-Ox11WP=I_C`~kus4W4(iO`4!g9s{)s0_{ng9gV0>;^QLNWf?ei5OdhvnD8@ zsHok5iV;m4P%vn~IANSHPN*3FXIFI18WDnz3~RIE04(W00i_G`UQ0!L?-No53K>y<4WDOT@eK58UXXQB+ERZKP z8Dd3kS6`))f2WeSWrv_-*3hmW<~p4|@ul?s6?--qK5@;{x0=lNC+#s9JYV=IU^3>M zqP1;TV+YGA={gnWx+Bz3SaRRSWOUCC_C;!*zVZDIC8%w`>fiHrcIrmk(7%$5-Je() z$wDK}C?~|7(Ot03slZnG;CVLQ+Dd(6fm%_Q^3f6_`kxb)J;7KO&c+4t(ZOT zX>lAqYLQ~rgdV1@_*DP#f>n6UciZrG$r_lNdqJ{zFO z&JbbpoDWCY`5$lJ1EaHh`Vv|Jb3)Z>`}d{d3)t)Rwb-@6FSsC9+x;)6hP5NtZB&ix zm|`q;mx@GpI|(@nktO!l+cp+Mh+NG;@zt?j%u}O9-!U!YNYnDKyC^sw=IV*wLI$hL zh3OUEiQ6TDBue8hyh)9|KcVB8fc7^l?=Xkz5vduYz*PnMuG!pQl{@Dn4?(n3gum;U zb-`hKkXLC*bevxvq-Xw&-=D`^l;QWIR!&IM^;AmF2RR+02V?&-N7j3f zgfCk)1|+uAUV8Ow)%L^b6<)DvGL|!RNjP#S`zaWXxrzj2S6+ak4>+ClyHW0$P`!_; zj?S<^4S6x|n2Hh)7zt7^P?8%Yev7WB-}9nX$aYluEan?}$RqX)nq7T34b6_O*%Xbu zC)1pcq0j*_7a(u-vsrDC5zCNHq-B!xiq|oJ(=EsH>DjwWWAW^XtqZ~a+k!ta<&D`7 zVGq&TB|%$f3Q67KL9f@4+QFwilA+~4?~6m_;ST-iaO4S<`t*v05sPe?ju18`{)79tyjoP5KicVi zWDr@NSc6ZFG6wSPAsW84nk#LTloQJT~+N zX2LrB{?%r8G-E}~C3^5!G*W_XdA#N;{P0IuFJwS_p0>F+9kx`v(c!@6ZgdE8I!=e^ z@SZ}JJCOfIRUc2IVxE59deTJ#>*X=ukf=*z3+V7N=1&?JUa?v_TNb;OIYeKkgLh0e zb4&CUx@CDnuCZXU#*`PZ<^C14Hg9x|hp;K&{fw$Pp&n+5EIbhDtR+6M{34&1lP#1CJE&B70ZR^^LZCn_w6$)MZcT(xM$;rFo1^Xpsh zi=6-a%yUb=s`f_TdH%HZ1M&aF?JrJLKQ=~7V~<%;AH<%g9w~^{)1e{yEL~g?T}@}p zW6n}Hy#moG{T7pE7)eJuet>WGX5Jnz9PS<;$s7tym_y!b=5TBtI^aY(rCvjdrUW!^ z$sgfo1;Zu}Ij!in8n2P9)}tfoRZFAGs4+*Qe`(9n<5zt8xZHHGmS^7lcm99MYxu6- zKwFB$_lXI5$-;l*<7}4DuF5O60I%ZltH@T9+@9IKKwlcsu(n{SJ$AV009LB#N!avAEX@)FOAh0L)? zDHgMN6}mNly3>8U%}-;Gv_gg5D(N`xlf0i-5$`P ze8gXLxW5!x>^y>RU_3I@K){3qxnOZc`jcI583IPqR{_y?m-)21 zni}+Y6Y{#n$8LVB<1;UrApP0(n=cx0eXG-L^zPf}8l)(D*|Yuc=a<@+!sNc|3oz>V!o!nI+pOgm|y2x!mFzHQrfXp z^A$rtHDAMI8}cbamf-^7yN6hZ>*7KS^r-wf_^!C+uOe28_G@bOiU;{WH63<%UE% zKrj7Aws;#+KUcN(AKxYF@V~qRUh#Vp&sT_?+(nMyILeg z-%@=yu+`PnguBa<1q0%8LAvv8h~7M72_Pgm$KQ@@1;r>kNa8VxkWN0Trr{FC6X8(N zX=WU-VEsk(*6A&1vhiZG?ueMmxD^Fyg(_A=Z?{PfuCbe^Kq}4OWy#1`W8EKlC^t)! 
zgI~_Z)n_rUaJ6!PRsRN$*fg7@;P5@dX(p6;TNgKb_EYdk5DQr`qN-osYc&%hpJZ>npdXnmQbpA-OaAqf$edNO zUOo!^_Lm|BqX~U&XLiGo*75ntH_u+R8gPZ_wQnPdq@M3y_LH|C$SQnX-VOG(HTGd< zou=$hU)St>SL;hZf*E_&o4ka+NONH zvV_a6@JAXFwAFNB5~eiGeusm?yd<`-QiIBUq7Rw#iIxh2t8 z*Qn!?rac?p3dAuPvMUB9J`w7;<*bf-GJLAzzd>z3f2)I(8UioUG(|A}U^=SGEplgC zzvWHd@ZN50haC-f`3$`-{Jgi{{nk;45H9;9RljV#>Kse2C`kB1EKp9^FZ^t4uzW1N zYCz1nwrvgaR}4W)sQnyoe*hAS^y%eoFkFqT@k!Vz%wM)4y*B^!0E4|J{&BheBwvB3 zlLserCb73jTv)AT%vWQ}PlRutlWffNKzc>|!}1Zo^VObubNF<#Y7-M0C6v%$E0fNB z=H2|D0bHAYdz_e1XCxFh2;=nb#DE+@OD^Z7S3OVE3GoPQB?A-kjsB~wT#X6C1Xu-T zxzfm1vn^(&S1}jNgA5q6+9#n#xOkLf@rZVLC;Y6S%uKtvtCw(&Vix~HFP9-78P*Fv z_dAoXr($5j5rehB_sK_TWMN=zQwcR!j%sGo?2+{5*+_GSxL%;OAF3IU8Yib)qF0+S zueNedtKoSfJ?piwnV7SKo>saEzrf2#w0`X92VB~jvE_$@T^9HXzyPu(Cc5ktwHVnX z-={I_WahMfy(|dc+$AN6)|a7=Dv!hoU!Bk5B%#7R4yhi~YF@S;)A+WQ#^_8fxoOnp4XHNiBH=qZW`G9Ho zTHdwNP#$+d()CtD3CY2}+PpBC87lb?D_{ImZ%DioaG~5Mp+bnaq32+n<^SN)ONbX{ zmIE&|UTUR=GkRr}ckH{izZ$g3iXjgx&jz)UF*U8)w2$xwUd89|Ccl;C?QnWk{3;uP zK8OXxbh4c3IyQ^f%)2ktZGEYG#!~l;JMgqyY7X8OiQp5hGfy)h;7Xr_e=v=;!0fnX z<*L_=>fAIc@IBQPrVa35)B6N$edEF8X^?^AFl@~?l&9f)WLZ#;Mj8f1F!D5{_D2SV z`YNPYc+sDJ{Yr`y3-a;^q*!?6j1&tu&KTo{)D?ML3}c2bB)C)Z@> zXh04d_YrYk-|yiDhF@pms^0h`G4{q0!*#|N#Bl95hUkDhnOm5{Gu^wrUDxCgnK%wS zL~aTY=&L+HcOV$fM;i zbV)=YE?VgWVr8|j0V%QfN=G4}p9G}Ef`oU&uP zKmPu`-Ey!(uQMLGarL~cU1yR_%6uJ8ub2DEd$l3Yd#IQwcFETPNrb2Rlq3fdt z@@&yMXH)b{+)XNBo-^uIt)qJD;;5e;8J6l|rYK^IRy9g-^V}M`xsD-#{)}#Jh@qRW z5$&=sLoGd>x%n@;xySX0f6JqwjzJy;j)u67VU%9Z6u6M-0Z;l{G@A<75WBP}g9=yu z%=Lh1I@<#|Ke0U!!q7?Ic4YXnwVu=i8T+W&^F4^oS+kw#0UIhLgJE_q{ebTQc+sJq z5e&ccE~0Gq^`LrdXOU1xE%8Y0CZ;HNB4r*$H%~f8J2&k& z{psdGbhE?Rf9-)+QO$aw>U0F!N8*e|8jmrX8mH)yi=cf9W#Fi-OzI|*>b%Y0Q`+4=Y3G3? z4dBjhbEwDDd`%mb(tBM}K*n&_pH1ddegexgW5#@$)RpDDzLGz0(u);SO{I}?>5K>| zD5}9vS&(0AWE9yWwHQxDI^IPyw(tFYieAd{Wh1f5#bS@&B*U)^A3tel^whfv5unsm z(6gYoo|_?^pCdi=FC|n>i4&v2;DRsRZb%t2enFoCeDdahDr{$tG4R*-%D0<}iYP(I zp&kZe?j=(!`jX(toiMrm&PAPjeBEuY?E#Z$-}WSqay(DreP#8m-_I=hgB8bDIxB+2 znWvF1kWI0OQqRkfF#SJla@OZozl*GX(;g$neDm{pY4;1;S{C_Qo>5u;qP4u}Yjs6y z^@qy(FOl`1zSeiN){nGp{}HwQ+qdltRomC1_I}FtBh>9b*0mok?l4x_VS>8D6kUhu z;yk+&qjGs0CJV*qhrH7r_lq5~)k~B}j?!V%Jg3SPE)FSbhjrqP8`T}N#7=o?r$TY3 z3bAvE+PO^K>6+T5QQY~Ky7L3E%QJD0l%yW(jCyWN>X~KKD?6#zexu%CL@X={@3TCa z8A9#vPR<iD4>YS{E;E4H zp(x&Ga;b}@O4}C(CpMk=AR8S9mI}_%E!rRp|(oB>{yxt9?FlB{; zzgycK{M~hJupr}Kq#XY@CjRdGGUD(47}uKky9X_)Gw^qxAqwOFb<*tF{WjhA8lnGl z@zIx{Y+n0q{|9L{sjJeD{o$be?u_@tG@I=H>vC%t+Gp@$KRD?47&O}F?$3CWCIX?d z+$>3>7$GVu^ai7J-?#r9A231$rs>ahEKf8@!B3t13H;yQmF{4YUU!OUrPXiu<~O9- zG=Q3!DOb_f8r0fmNewJbcM_PvW!v`-_M0Jj@zYPsJx0irqXNc)uX^Z@+dW74{`g*- zs2P%DgT7iZV1zdJ*1Z_qHuwi0@INQXQ&%zt)$v57J;#$nG2><~Hw4n?zX^3Lv$ZsM zdVaH!PT+!KV^#z&R|Gy8A2iqP5TCX7vJ``@c|UkES>jkn8RsADew$p3oyoPOiLm)z ziL@9+@r2FyDx?ypR)oz#7vL%y0x#9(gw2DkN^vJo*!;-9YaIOD5#Z8{uzAR8Q}_XQ zMxD(Fo1c1keA0}txvbp{J$#}m{Gf@jIc2;H9C22WD{~WJb8Wf&Iy}L7x7*;76w_P* zVROWU?eJjE@Por`6JfJHxHuno@`TMF+Bmof2%F=*Z+{~Tcq?`YFjI@Xhl6&Oy^y9^jIXscDcEYgJXn8c(a_xka%gi~tV}U$JX_CqLKR>uK6t%yi+GFbZ z+QC9zJphEvT+(-d6JObU?XS|-$*Dz_B%)VitZCqEa5&p6Q{!rjq6J+-e<+ykPqPT@Hv@PJeH z;D~LsV|$~0+o*l3lor&*Z9Au3dc^KK6E~ zR%O12uXba34&MA;uB4FBl71LgRIMf{!yG|qA#e3QwrNsX8 zw(YP$L@e+W#mi4fZh&$5Cu+iec^=mU@w$U(LM=6+-aH3QU|(ZarK+UFO$i=k!@2$X zQaVSUm-IHv;rcve#v)W+MwOQqHIXYn@bvIRSOOS^*~UB=?md9cF{oTEigLHl$C}P1cbIdCT_C^3i~XG%99VKkh@E3m$cz1;QoH)Jo3bA#_5=G%Yw4WC z?rT=aiJdWf2gJ@I8uEssete(H2c#$`gS@-g&~kCAO;iME&1uh9OZO3SI`(f(#yDIQFv)!!w(kjjud0xe4 zjIUfQu)|?f<*5lOIg)SR9=652^^WZ~;;@A|650d~fUsB5pm5zAgrG0y*pEwu9+4JCY6cm<^`<6x@BxlJZ(YPd{^5oA;kt{T!1Nfp6)AXI_(sHs`wOn5xC|Fu*|>GG 
z;7WFlu;iIr#|iGKaLS`*tr{Z`cwCUnH-cM3E0@g@^z!8w^CkzbGd`$pDDExW4VmY8}=$6 zx4oJYM;g@dPx~P6mdGelUb=uvQ%pw{wFgkewJWG%<^mLZAU^_z0Mesm6;ogq2}j|Y zpHaAWBDH59ZmzqGn;*=_%^yMa>G5oR3>#Bql-;qbg?={7aN6Z zt#q;M8-0rY_jIwmS8M8_1Qp!teH`jn$L;+dBD>AJguiofasQy+pCh47&Kbo;B(#CL zRPQw+=j{AK=w0vsoc_?u2x=HKuO!OQ^Yl-sY6Z{UJpP?g!<>;^iS#S45fT5qKDKm* z7uKejWq82%Em~xATP@)`UOG0H#f#Ob~a{QAZVa~IYKX16? zpqzWYSHCnH)vda(E+aKk@r>oDY{}%ms~ptt-re7yW>e9){QGMTx}bge*N`^~MRle= zcUH%?@zdMl&U=(5ShRC5MtF zYo=s%g_0#BCF_Z~(Z4BKhZN0~tRN^=`esVjSIjw7jnX8=zbIM5`cySjvXU#%@*`Z! zBf?8=G3U^8^5s>Wo+EBjvQ!pFc|GSrWgL_&bze?Td0?-GlBMoTV)r#W%8A`f$$DVa zT*-=?)jv?ri!C1!Kby`W_CX6HXl|^MQ}Eq<^!a}3^ON;WVz0G|%wRM!gSr?JmEy1) ziKe_!c}hRCQ(WbWIX^?gDLq|;` z741%p0t_MMnLEZZQmY!uoTGRd6q-UMjnkInjx4LW!#}8?nNzYpH-|Jb` zJDcfQ<4K)rsb`&Ure}Q~B2+Yv2(MhdE4*^uevII$8y^X;+?G-d-G!P-0%5YZZV+4v zEw^u_JUt&HRG#t^CJGgsgsq!|Mcu4yRn$?$hG`cG73lnQ=W~LG#+Y9VE#Z@%%0UFm};C$kFW_$^e5 zCCS=wi!46v@%yKs2`R$ath94~HaQnb8;g<^OlfCGe<~!EsG2mVg`_#<+qq$nmG&r; zrdU|29E{X?D4~{gsXEf7>PeS+LAumRdhqy^IA*VQY9DVj&%EdWZmu6^@wKn_IwOZv z(zH|qp+ZHFrX?axtBzEtI#QuxNQKIyG_;eTX{mbSHeb@T;z-ky(`_kqTO(zzt-v4K zam5=yAwiK?MP;HJnSYUTkw^hsPCLmQM$hVktF82`hVIr#HRBrX=TI3z&bSY#zy~O5wMO(O;^rZ*FmB(W`meBr? z99&=5hfK?r6dN`%o~B%A+Ej|ODz`eRF~L5R9j1Q?&mi7dKH!=x9#*#y2aX=*&9e)( z&TJj}hwVT{=8Sb@j2X;cW259!_rMHH#3iosVykmR4>!y~rBqzY*{F0F<&`3CB8uxZ zi}t68;UxJGzgFA`bw=nl$3i9>O>Bz_l)laEd|Yw~_+_p~NM~VDWY?>bPzv9UC8K^C z&+4WbcLVDL$5kaB)UhqLOVA?!F5bmfSFs-!U+lP>OV1dg8W&F+G{lPH3z`+y9$G(o zIH=$QIFsVznUjDph-M{2X62=#{jGRoXFkPaME$r=QN}YWV-#Y@4K~$a6VTm-a%H@j z84pviC(791Yvq-t>_I)#Nzb>V**?~!iCh`_;L{F}rOL2yDnINv6AXtD=#R)Ym2t?# z>xVwAQ+7u8vAZPE-vXTlG->=jxV-aF1H{Hw^|vpU1kcD1eTacEVpg8PEqkpHPbI)Lo$Ny zgVO_#RtiI-&~AQ1sebG)L=QuWKvUHUympO-o`92s;Z-W{j`Q2&Cj%mUTcn$yH8&ta zHSIw8bqtu9q4btvTzlgGE!OmS+slI4F`edUBO)GJC96QKidtlf(whm2sY8ef%2=9 z_gLM0wGq;!js_R1`kKxYSuQRxVyhivd%#{X-!Ncrm|TNFaC|Gtp4y< zPqE?tY4uUMxXkhULR`Fc;KUV|CbNuo@aUUIkWBfPw>^+y`Piv1HOR1>x!q;+?+)rI z#eEBr@cP)6q(W?~Us)A|49mKXMfWi=7S@0N1?JIvH9Ou?qi0Ry_SQ7Pm*a-T5oM;7R`h@avNliqyWPiMKm~GoB|E|RQgdr_? 
zuPgV(_F0Y3K~1w@aS>nPEe{W{>*+bHpV98dM(l&rj`^2bZtg2>n)@#Db6b|Z zxgGP2(Ip%4(dzNtc^7ePrE%jHW46-MY~??e>#5q&#iwy+IS0`P+f2e{|0oD)(`M{d z^@#T+i*>r`E?}%&tTE&u&~kHtZqh3VFEXbux+(Rs&sAl|3NWg9a@`iZEJC^zyL&{Z z)|CBtt|w+JWlx%dnzA3i^VqHI+${U?JOBM^I!AMwWj_wvhN(wG)i%q1JklywkF9?` z`|(JdxLP_#7c^x*mQOr@E-0j3`KIj08Dot*rC`^a&wiZUevIjMFi8rsA5S}nYJF+` zf~yuk;}uk7nC<{)B)<#ZcX*w;C$t`6_T$f%U{hb&mjrH>{a7(B*(JNES@z?y_D@Xr zf-eBU0O|`ISb!|DL zJOyozB^@dHJ5W~`Gl*MZh+>F$6l|X$aHVn)*TiszwieNiSA7->BWtnC;w9W=x#6?8 zo3OM;UA`6GcF$+Ao8UIYXhJYwoZs?Yq^W?`h4E}$39FZ~`yN{IKweCn6|n@l=iv`MR;2H-BK5$EAz)aXO+04R_v6a zbbcVdWO_#{z2nQTTv*mplZu(G1~c;+lKyKi&c@!vm&@-jn<3fvD1A3ptOqOhodMqN z`;X!gma+UVlM1~?_^w&<>l+~Q*cSsHrP*kEuiE;DgE;#2&<5PGC+P7j0|;kb$;_s_ih8f!y`1;5|nrh=2hMej>0U%pAhvgVOzObhg9W8g%mKK@`qplH<)ca z1`7=9SaE9{;Mo}i3AXr^1K@3ctS%`bp(-kZG=&J!xP_Q13TT~h=5sy8jh?Si-#~(m5o}ka0dy8^`y+g!&_Oay#Gw)^nr2T=f7eF1L1RgZ+!Jh+IaUql`G3bXj)D`OrnPTd9fIu73NZ?E2lg@hc!vhlHVuE7igQ-%gI)8dT zcdo~J>4cA|#L;XaC%ioD5ro&$dRd@z;(Ko1bBmw!$kH9ht+yEyYU|A(0W6RY*bQls z6Y4iIOH$hNS;fUMm#}n3v_(J2Oc&#GX1%3K$GDf-mT`RaYUiFMZaVTm^LI8|9CBC2 zWi3<3-ow*cdfL}?37%$qvErM)mcKT0B&)dGSEk1~Va!fuOYkPRp>xZ`7Tp+tZY=Ip z%XeeW#Nr=8zRSHs-EP~xUr#3I1MPSAP+LVX(FVe zcc!!thcVZMKS76(*P6eJSt4FVV=_p0&9}*}Ez0SjO0f%I^I@gH&#SvC>AC5WcA@Np z!0?lT@#a4yWA=6hzkwi*u^BcX&=)Rd>^D#_UiQ12xF;$|>0$O`$k6qtxvz9>8hpG$ z$j2+hd?(-I4V&BE1TacsGVLi3KKM(*;rdb?R1kM)Vh~Gt=;l$ zPFWGB-#`;J_Gn^``F_5MW#78@$dzCxgnbycxP!=+MorwGq&5G>RBh>1?>`9~c0bjd z1EAB+$?iYR{d-$_rH+^kC&C<=g;KW~^&J8UEVc-;L>4357@K+~*5#eizN7MkYqU=( z;@1#k|J)AFN}KZlJPz*xY)aUwW3*q(e3shbRjDUyvoyr+Z(ua)n}?@DxsJI640eAu zROx4lA6#x89!9b8ITW9NnDp$}N;0%B2L|(uzD-Cbe-9y<>X(d9lEIbgR+n9@dk`0} zes)+3Zy}$tVTpv|68#0}3C>Nk7ld>LLmf2v=3#;P=~lI9_(^a73fY&0mI=oW z$F~p=opnUnLcn9`Q_WiR(i5M>ZOWVX65h<%7LS1J#dKVY zE1-VB@M3&{;KCvblMes;u;u%%EATgARs4^zYDT`vX_1+3FOaE<0*2z|SXla+NsBJp zwTBgX!C|dP`kt(aJb{W7Gbmh1VjmVF(aJSst{y>NLH`iEZ6tV`fI@0dpb!mt`Sc*m zi;7V2(goyc6e9b4z5L7H6zT9Zno0OKmGJNWD@h3H+m^>vry=XisVOqCiH%t?1B6o- zJ-ck=CrV;Gep(usOOt_oEg=PqcNaj&b`sW4XC$t*!n>MYoqs`$0fSn_E6PlyzTMj1 z6t#_!S=S?crrx@t7T~t+Yy4MXx>zKgDKcH?D_xXh>hF?i?38KDKua^SK$|7W&Dqce zd5?8AC0z$?Mx;aEKM&NT+3Y{|;dU6EbM78+gp19whT;N%;EKNHe*@xnyfXALAnwVl z)dhgKKQ|^IY9a65%Uax#J>bM$wqP^XXGhLD<7{Ux&a_}JzQdgW?Z2^il3nQ26Uk4! 
z0Ci`jPbpj#qSfu&D{T&rmpL#G6fqApI5BPhn-KF+pk?_##yl?WF`SGhE#oEDw{I3N zv7QN@z_P={OQ>M)G(KpmY%^EpnJvcpKr8dN00T-jd}US2QTYI!^<=yu8@HM+X6RnG z$12p{u#V;UNI~~atB|WhhXN{s5`H>dt8m5=vxLA!0H1HNpq;bj1vfR9`m%QPB46)Z z$uF3G<$hxq$8qGN%HZ3mHHV;Xj)Gf&#Y3&yVPBS8-u3O2h&iByHf}~7)Rj-5Wl%+M z%$v{AkP%!_UuNseEo!eqZS=4LVYT?Go@;B?Wx}~TsIBXHzWIg?E|v*UAvo~ZQ0^j2 z#5uvvuCRZzR^^4OVWUx#RoQqv4?DYf%D2Hj!9BUI{rXhkF*0y(4B=pB$#^U(lAozg zdpk}jO@)3Vkk$ah@}wG@G6R+igid@k)o6eLTtrMf1R?b#WK-CLer)na+{uMaC;$nMKIAvLE7TT`%wu|op#ba@mCBj<;Zu;z4dohegt3ia-^C+(>B=A6ij(g+T5@ z>SqM#@nTQ5TA)W`7jTV_m>_X#Kqd)C8pw`j^Er^~8&;3rE+vqgqih1X3p}+-Z+OFV z<>{u_ERqBC#t;lyD9_`CkGTxt;|QD9OXGOqwc`gC77am$AMSba)bEVl^oT1Ud|TpV z2Jp-(t3!bv*^zcCzb)XbX|w?g+KuM$h))zUST46q6z8Y_tTXj<3FTqFF5vq8dQ%PGl?8SImhYgM@u>Bu4z?0RRS@j`s8KEA zqbeF#Q&h!h>c1ZAcsQvDGj>N5141$kh%n6Gk2yyVY?rU&;G{D3IV!KG%9|tJ+AI33 z5;t8>?)$B^yP=_T8bm$Kc5;I#A}R{{Kn6uMX#9Hdn$ii+fQbHq8ouQxhL4J*bBL}s z_v^Lsw(VFyq(s?=xA8>X*=(=kn$}LC#6T<>$d8JChwU&b!sxyBE`U_*sE8Q-{uoH{ zbJKaZNah?;9gr9J?HSIaY)?C&Y~56PDQ+cvw`iyLK4^A(o{I*Tp* z+^&s|d+<gre9@&leUFenAg!o_wy4L&B|ZFz0AZt{g5y z>{_GCd<)1fqbP<5iz9f!jK)(5jYkaE9>WQYB;?nV(#Pyb$V^NXGsRrp^JgQ7K|G&& zqHi5vcTV`~DNq-Vb9Lv0pP7PBnK7gL{4ag*5U*| z%CaXp%td$baU9v14hvxzqBssoKeH3ui;kO00Vf}$d-j2vH_yZ*ktS|1x=q|;;SR)pa?h;lj z?teKkI?1#o^vw&a=2X~M+Q9Mjscl+l1J8Y;Zj&M>WuiU8m}rlfFGzVQ zJO>?R+RMX?7g8AGg_PK%rJe&_Xu5z(7M63L#eG}e1_jBh2(G@jmP&F^my1XE#9q6B>(TC{CW7p}2$Fo5-o|ooCMBqzf7}cuKgjB1S z_E$L7s;!w3dh$uT%5l`06lbv$VJ6RlwvK*j2rBu|rbE&AMvj8DBL#pgOOUnbPbgZG zg^JeMmR)-+J*$F??R;*|nOH zf06!G%Z6c)ey5$n&{34LyW_jJxB2?}&ljnNkIX3UVHRmNVu9L7vPR+TuSU$htEhIq z*j1z)9R#(JJcDH98I(aTKZWG-6H1;zyC}F~EX82-z)rcgEpF48;x-5L2&=QrWcDcb zF?$q)aqa(cuUrmoCfnnBC6uk8uiw4H*1`Yw{x_l-5T<~A_Qt0T9ibl0t9u!XTI%Dz zJaz+GTyF02H97~~wciS@@nhHDU%KJPjrEh?z-!3&)R0E_6)8^~ng4eE*{KntnSU#+ z>uaSqL~ZNkZ6B-J#za^@)6joEk();NO5?P^v8JE;K2s0OQ4fTy`HQ%H_xj-9VT$DT z-C0a_D%XnqiVV8b<2V2WT7A@V7?qS>W$v;$GE1*@%86y}?dKR)J+VV(C?yh{L9q3V z4QCKl1xI2dx|G|SXmCI(A{DKMuAIx!)sZ3dWZYV&4BQ4aN<+$9Mh_0Mi3fhC3qdL8 z4cpn`fm_5>>UqGWu1Q=FsC-;BU~eI98YqIewKHS1@z=PM%TlQvWV04}r8gEFo!JF$ z5f3_tI(5tBI~^})v9;?CcBzapD0t+v;7Q@(>HrE0mot`4w@Y`(p&4zP)8 zy}d4)xDR#aQJqcO>v3a~JEk<>Ue9X32m~3z?RDie9jcWJx7Twja-=P`*WHHawAfx> zVo|`uy-#OR+`Mp$T|TY^&QqEE`x@23A90@C{yN19$0hg6OM(Z7bRJS{WCM7eZ9BKt zeh!aOrksFSjU*PkcTVqAf8x9%P-{@+^d~G}?xrHQNsAoY59DY7rcJ}kW=Fstemi5B zQvt7|h$b9SR6#StwgGF{S9RzsJ3>3^DC?3miMA6Mi$g)vc1E0xS5cIcty24z*_Lv$ zDJR>UW6m)Yv}&0lNF6!IZj+ihgrKXlp08XpVc1QmiA%WAp_!2KlsU%;{6wzj)I4#_ z1*mx;q~=AKF5^*S;br++pjR08*_Ib}=$}UC^vQNIgVx3FChEf)(0O&FV^a7!_m#wZ zPQ3#aidU}UhezC{bBMj(G|8}IAa+Qd_#uY>I~x9<>6^r!Zk08NQLA)hv(YH0+gxYP zA$Gl5&x?KceduK3ktDW@ibu`06Nfnzuu<-8%NshlN*|M!h033p8X@|f?em-_wXoSn zf^M$WY&jd~ZqTZxlyoeEQF^ih`AIx6P6&v#1EZ~-~AD|54ZI5RR!O<1x zZa4ykO`>qxEDEQcN?@%gfjiqUxmc}`ftl{(BDc}TL*J3vyKzbgLTI&w*RBv=yMv3W zdNQlWl3D#QV4a+-=5qnr`48@%mOEEo59A){WAZAWb{kxTqOO6A!i-3Fy@lq^U<=MxAjjL7H-im6dZd zkoLsi8_-Jt(mvclzAG**LD~zCfcR!0t*TvrKraDE%eYJ+ZCgu__MklOMl+C>6D--( z3Z!LB$cB)aEkT-Stnt}qAWa-#GYuq0Eqx%>VLO19a$15k-{AWZ|1meCxf?FL8#?h+9WxzVP^FeJp&FQD2GAxJ9(NZXA8#+$qQ z4bM@+M9=#`?yyMUhiBq90BJfQNDFBR(v(51L0WMStq`QCY|CXP{|iVf{Pts1?nae218M4cRSddonuOO|ql0cfL0HmqiT7b0D{|cmq$Za@h9^I;YFtlsQ(9Zg8 zu&zLjzsX_0hTjG|S=R@az{Jn{ZIIJ!7UuYG{aaw2{Pb&D#xAH;7@RFqS#ONJWH=-i znS|>7__M*4Nd?A*jcrJ;T8yhlmAK(R*b@AC{2O>~T-dx6f1bGKXa<8Ad0FJGXDoc1 z2>L~DCOhNv1a|>40oG%(@ojOr*b3{Frp0r3r`;H}$K`k47I)wd`>OQdB!&efp!}5g z8YZ05XR(j4TsTt13xLrRpQxXtr$H%2gK#xo3JtWH!uK{RH~mx9I^ zf*lW36QcB7Wi2o&jw7R@jo>0g0}>Nl1QPREY%jP7rwE_Lj)IHh=wc_q#SL^(F1T2Z zi&5*z=$JQ|jEtFlRRLs;lJ{sz@%0UK9mk-_%}wS(rz-l4U&;}TmBL`fkWN;~@n~Z>{l=(&E9j$#AxB{4 zEGSJy=pF)TdU&tuP1hdG 
zm7IpBp6nJts6q1$B^QcM2s1U=01JPu1-#>IfVwxbSl$L`VyfgGeNfQ}kFKCc*Ye&M zrQ`fevG>DyU&w<$?u4zdmWID(-oj0WzqsIwov?|<8UA!bhc1V?s3oRUPF@Y%qae6N zhCfzISGF051&Yg_9^!{^LjjZwgty279lpUs;}OTQsciX0AFvI|B|NPno)9UgU1S) zIlIMiW97%^z;;S+68I%=<{UKK>ept`HFcrn#vMUc;veccGpELGAS%2=6-LT=M+RN+ zy`l_Qq(DT;np*0|9rvvR5_5o&$=|^&iZ?P%xd;5RpYThDvWcK_rlc~!Jn9IdF_OfnanxFo-g0e=~&{}8-Q&Zf^E&?YRWFh?Z9B>t%7A8)O+Zh#C9@^ z;l!3re+#i2No=i>_$gV}*5#gY58=Vd+t!v3kB5h{oNQ}iQ$J_!Gs^syU|Uy$HJ|u4 z+1A3XB4t2MO@MP^ibE*|I{6ltIGe?DmFGGm4E!b|>PBC_VhP`Zf5wo_v96KA+f8Zt0xm<9<*L@oV;k`-tRU+O=Sh;X@ zO7DUah5?g(zGPECVG*byQ<^+3R*PW$J?GE$gD7eU`k{gPA+@N9N%Po7Oe*laU@@I0 z&>Y*>WMFhPFX|Ecv4q=;Xy!cunri+Po&W>v2Mi%5`$|vbTSz=`jsnJ*C-uae;6C#& zw~;Gw6DcO!>~)p(a0TZ5h!xW%dCf>Mm)PBub;wB03W|GhRdV(R_VATqV!l>26+N?r zdS;)NH%TgobV8VSZ_b2J(r$4xh{}Hc`r8w^KGDxufj)UaJ|P7~M6=0EeYdNiP#VZm z^tT<5T$>y&ZkQL*6c_0VlGP+P$ufi*CT{bzV;7A>7&UDUM(F%rnRMVLnRyf73!$5b zH{aY3Lljk#aly8Nqi7MQc1Rfz7(x~Fw=n9{e%?%e7i;!Dv@MU^7^;f;@IBsbrzIw1 zo@OPp%CCK{q<{wPVfi&blSujVb0(p3`;ZWkJz5yQ2*fq)D~KPwG*ljN;Y^_aC8-nl zT69| zk+}@sY57YtH;o5Xer~yvlhnjQ`7^)UN(9~yQ92g6757sOLr=q5BI|sk(G5(=Gt`bH zZo|A&hlLz+MSYmYh$gdpO0p~T-B0O9cn z!yw0G7ow65V*6alU24K@dFHTtBfkAr4bs0^}nVYRlu;lQb9kC-E z;513JnUXrS_I8L#pfJ!~_hhutH(&+1#G zo%QbwJ*<||!$#uj1t;>5$WLg`GJdcP^ND&$Hy8*t*N_TyVKeJ)&iaRc!)R~==!X+d zi%{%Kn`w-Liy@=)xGm_Rs?SI%O8mHn3#V`45q_(Sz{1{*wXn0k5?Q3S99~Uo%gw{& zD{){m<#S&ebBlgBqGOyWCJ-Ozh3cQ^hgdkDN9gclO+&!9I04Jd@XEsl7$c#}ShgV^`ar=$iWORAvcKfBC{YVi~j=onN| zZYVKontO|(RDzCjhK}28C?P*rL+Ps)**?c1>3sUNOD%6ZhTHluc6hG*@@b*X-vZo> zO2^OQbY*)>1&Kzj+9TxC?6y2&6h?mD3Mc2a@J$L1_iv&9cdB4c`Rs>g_-@GF&-!}! zAVW^^v;Z(2lFeU<-=KEI9fjJJb1lh8mb8s>b!cKO%ze$)tbo?$dz2KkN=D(zg>yal zV#Y?3^F_`a#<_kl<6N&L9qQE&bact%f7g2+dA z%>rXi){j4ZS>vD!zd83+noVi$q#sehhOEeQ?+ED zr!5=9?s(9RqBv+#rTqDABQ>1^Q9BfI3%+8?Dd-X@q)T-DnA2J`Jye(oa)H-+aCPc}tuEqDZ%!roaQKi~pm18{yLc2^g?c00tZn#nOvtjbr*;@9scXf66+fCQnD8pV=bu+0MF1uFDFBL)B7~sx_=DuU^J_ zYLeZXA{h!hthYYzrlS~$_43QxctZZ#u-A~Up5zM|&&d}u@quDL)CJCyCS>A6-yAx} z(@oI~>nh0!!IgkaQ#3=}WmO@xA}*RiGtu7&T29Mo2G#Vg08yEw8Z<>S)K47t0q*3Y z8Fb+`*3g&OBoJIQgYwfIP>ym)Ickb#5YMyv4R`X<4B5j+HDD70-PbId;qTkz^1!G? z&7v7HrWEI+PUy?5nv^j`a)vpFoXw&cnt7o&k7g*iqJ0Zd5p}>S7X@~g)X5d zY$k_Vwc*~7wvMFD)HNyLxri5rJzK*MjfFN*zOZ-CCDb@up2aDI&AiZ&{)mUnwO1mf z-TMdT@Jf6E;)Ophv)xY$VYXT$?U1eJ5NB4^B)q6u^r9B(c`JfC)@q>^W+kZ223SwsTF%C)++ zMLKIQiAR4_yEwvBcbQeG44;QG?KBiL1;uJAOyha6a_oB{6NHk_1d%hbS9|m+uVe>C zKx54cF(oTmj!_tLIV@9|>Rw6CZVHB{x{7{`GfnMO5iA*dSJZ^sXD2{{Mv{Pawqc$0 z-uOco{gl&1bAnCKpQI3zr%kAy|3uJH!vk6 zlJK9&pD8@YjHTn~N8I-jvEXlXp`LWYYes!}%{h96)m%kj z247C}1eV=%btp52P8mV=m z8PDP*-^dDNU&sZG8(;X62~LveGTTOKuF#z3?Sxl-jJFK$@1f0T`$}76rBSzKIIl3j%x(J1`cJR{8Ut(Y= zY%wtIX?QJJx!0<4IVQPcZlLV7aacj;5IZ}Wv}acAmjk*$>@*U4a2_vq?R?{9opU9R zII(Nz`+v@yL+l6gJYMWmHz2mA3&gIsuII%r`&}lz4_kIyPLkXasp3yD0aE3K$s$~5 zD|f9#<+^lKo;;+CuUxetxntp3H}O!e+iD#bN$X+#Y75C-+ZGMsB+tysgyb_x^5+qt zOE937pRE=sTyiBZxyI`kNSyR%-E>3ggC1)3B1G*}>ZZgvDqE9)ar8>nPI=EPnjc5! 
z_ngBx>IRuKZQt-RZ467aY=DaL)d%2R{ z{rPkMO6DA5_jKLO^|`ou+b`WW3_krKpryStZGoQ}jg zT+gY?Y!>NJo)_`uc)C{kn|y-swF0gb*~$|rrIAXxl`@(y#oNAHFQaa`lEGXnRG&m0 zdN0yFFo&dA-!5r1Cq?ED1(4zvNipP7lN7&N_4C8?KHT&APbBL$o^w+N)8Gj<8_Lz6 zv+^$L&kjWWjgm=x{ZU?ev+rTOK5cma>M9%0uTfqkN%5hpiX*)cKR9?EYeZkXLb=@x z{uO3j&*-wyJylxkVxd-Zb0hWEVdo4O7j|AAiT1)YT zC!vePJh4-@d4;hZi;ig%6UqD@9^)Dwy$HONnto%{C+j+3M^gr~uZ=Sott?vQ_5qlZ zMxoQS!>96{zROFs&OK0W&vn|ak5bC$972oj{ias9x^r(WVwF1z&{F|rCe!!=4t(TZ zTv}#Z!M!!*{ckLl?%0!|xAuvvo_lM$f3BAy==lQ(9bvzW7rOKhRaTjscskeXrJdJn z=$r=rG_z$~<)S3@=TXYH|U%spKJCb zSGjMp_!=r-OqK75Oc8I7Vta{6D>nrFAe-pZM1$s8}|u%)!YK|dx%%k#%9$C)?=XZ4HP z6#>WH!izoI`=z}03_-)txWpeV{IIm&vc8X{{kYBvTH4RtBvBu!8VvRS5lu3&4I>Qj zt5aubAxOfU*pZ>X5vd=XD>+5&I^o*D^_H@S^aSK;B)L={Er(n*hV^Ts5S*QQ*07JW z-2oE)Nl(_grkssSwt>=hQ`uR!!}{z++lF7_jQ{=snyVrMd9oNU*j;j4c!V&pk!e?_3nhZwr5&iDACyFXHqLE+|| zxmja@gY0%LHto)0=xaXD=)5E|nCpqUsE7}t<^Mw~=E@x4b9T9weJ*}&dG~`rJzvHm zhwQd$pIpg9lIp%|jzLtFuAlS`jlN(SJ!e1V-}6!j%hJ~vJ!iS)qLb9W2+!Bk^N(B$ zx#z_LrE(99z7RZL^7&`{^Pf6cUU&mqEADw+XK9M@hq;oc^!yXoQto->AnE&fK8v2$ zf0V#KuW_*atr02+_q@1=w8S6Jzo6%zxmIw`s|QPa;rSeT{+ekb|NIUI%gVPHRowIc zgS|J8i?VG0hOgNN7+_c(HoX8BM9fi9al;uF#eEdqw*e7Xa5OZ_)LBqbQP9Ynl58|n zEGG_@H(rBZJcR7}co$t5)@F*VKJ?|EHw4NKiUzvp@1&*%B$z4??gb6w}Y9OrQ! z%lA-s77N!P|8tW6X|RpUZyg}+hWwQz|J~C~`9D#}>R*8?!jK+l0v(c-fS?H+6%6%zSq-Q1!B|SlkoFE! zn7b)ThKS^I5qT`ExDc} zq%TvHlGozEND`r8LAEZWle45htca^I0J{=42;0O4F7HukN4%oM-Tg|=15to9%pB6i z(J=E&y>lhJrW~EStJoC-faqdUYEgDE&fl)+rc*@Bjw+p?i+POlTU$ss&qB9F3+@qe z(~Z1*%~+D~3;A_oZ1+kRdr!VE6txM|7&besjJ19ZXn}1AS|AVY;b=kqOl9_Fs6Wqe zLt`B%u2jHp$(n}%74b2cif~M9#YS5KF2Z0M`{L#KnYvK>)20WG`PXbQ zbkDjF!OP_aK0j3_-x)dF5{mV+BeZ-|QhZ`REL_s0#)JYii9L*$_*GuOVLTk&wS^9H zbhkLw5_8u1tf!vCZR8CXcfC~%jF30|9TS)oJj;W6FYg4roDy7#%#8dG&|P#cLGr)kIjgkpj0#Ga$R}1j^sgu^LTg8ChGJYBUj1qu$uU zs-fH&l@1>ZYayFRJ!Bi843x+j=+Vhom{jcr#=Xj``?so2OFZ-G(YV`@cW;U_L$+HB zQI+rm`gMRXk5`T!V;ykM<1buj-gX%N_YrpT!FViP+=&|x-W+jkbXOSOaN|L}&)YE? 
z+|YJ1o*zix7<)R}!ALF_cj*PAh$IZnt##WS^lZ0ZV~u4rWQSZC+V>5EJpPCrUQC*} zxs!1_*YWDvR3bV)0v#88*ocm^M-t{rudj_Te8|n|zBE<);9kRRB)_#&44$(CcWbr= z`K=?7zv4r^Ljgf;vsUNTMyQLp{FXFTREJ)}<>AQxOD8^_v-zLRHX}c8H1eB@avb?J z+pP4C+6dlWE`M#BYDDK=!&M}IaCo^Rf6g2rOS-Bjl8M_};K*;@X4M_2jj%fM-%V3# zyY?EcBl*XK8}Xd2-#kY(9{rz0`v1u;NB)LwR*Mz=|AedmWtu9pd#~YUlHU*>gXe7i zk~w+EZ=Fo~|4FeUziPWxE{^OqygHma4{rI%Q#6TS2noIUG^ChQ!x(bJ8JI!~iTY6X z&8LnjyyquN(<6N`H-&lId5gq-dkv?$pb6K*%Rf!Z6LVWj{}FE*meNwPln#V+;g*th zxmXZTjIFef>&CI!suDH$)H4L$s&YNug(E}*{ci!By?6pPld4OpV>4CHrX?u|tx)%k zowMMr5WlX3`J-V!S>gMorHgmyiZKC`HevwwKNzMx`u-RhkCZn_%$Tva|S_#qOQWv>~!bW&U(+QR<AjDbionaT85#(&+xBf~5po$8A(9|{3p0eHxs7GHf$drueP;m!Bcx^KJIE{saCKw) z31`#jn_E96SRnWC_3CZ%G{Xdc$;WFI`c~%i*N^2YzA)Nqe)|O0AS0Q$t4q>LknJf6 zs*{;MWhLxGjuPO&X?G5Xx9;fm=#)>KVQW&hIUETdF9pO!<_gBm@=3-$>CjuDZ^Q5< z{XH0VTt5DBM{l@{mKX%mm6k!6KBj1t}cxs(2z6sQmu+geUY&6p}lb{SG?r4n3g1f zgnUG-R7!hwf%(eSHN<>n#S{Yd(L3WXd2#e>V!m?c6ccl3+aB^Fy#|AD^~-e(21uth zt%mptCU9=+M(nyDy||KhrTL0L)fZ}_QwR>@m;9EKdpSDgrtd7QH~&S#$g3wvu9c@-6$?Dg$A^}o*yBLJ0sKhDZ|LB) zL-B24g7bLmn|Sg4w(EFtc;yCO>#-m8qUvjLYI9ti#E~zgxi6}wiM6d-)%Q{8p01h} z%Q|}T&9wn}EgqFV3_bR9=xx**IqKRfRZDA`FLMhQqGQeNY|3&Qj^)P|ArmS1=*=|u6 zK9`S)jUbC?*5BCioMS0eh%H!w&gVRdKlWKb>qG-%boZVZRnT& z8_Ky~`U@)9f7Le|`(Wj^9jKsR#u_PJ9O)W@7x%^M@Z!Xj4&=qcZRExA_yBryZatt= z$EWM#q8Go&o^dVhFY6x^Y*wo^_0IkmRU<(cTCV9g{9Dtkd&hDMKex18Ac#_ZLm&mQ@r)F%`vukFsk`py0D-z|*0o@akgUfw^Ii{_m^W7z6Cd0$xX{1Jf&< zvPM6BGRaQ3?Y4sN*<{ z@lbmFc|*Yk8{)*R99_?5WgMb*Vt1b^BK>4$JDYxQ)+3+Y$z zB6Dd;JC9>;&&X-9{O)0L+qCNNlAlB0eiYhgQHvIWOhqrg!K)^-6S2w5hmdB(sc(kK zmbeinmQ#~kQ<*Hx(KR-o-Dx-c85EZJMr}x?6}cWGRen`Z6C#K z#;`>hFW%pJWNWZMzHhxi`aBAq-A>c8%Rh&%easyj zb}+GSn^ zzxtsOU0NamllUKY-*pb z$ZtJJiJ!fw`mxpW``|@OyH$cw_L~EvY$WXGU^8mFsD;Vtm;Mp3<{WQ={b<>|V%Ur( zxx(sH`H2-asV^fZ!iIIb7>Ug_B=JI$ST)il8n2}I$=+2v+Ml@*SsfZ*o3xvxuOjI! 
zo5U`gb|89*R%%@7M&zD zey$*W_aPH7D`6SpJ*lJcO>Q0LXy7W8WP-Ch=&gjBjaNj1nZx{TG#v(HR_j>K{$?NWj`R< z*o6j>l-FG$S{&wdFv_(t!K2q^@*G^7oO%P>eCav9NNg{H#d`dU-CL=Rxs5 zM;6;9WYH36yNRX9QuH0N#E~p3h)pvIiK#CnT{ug+kcMOm;{N)K;EFF^VDOQ_c9c zX1_U`7Q3#1Qsk-j!U(c|GUUmdd^lpAVP zquY;6ng!{XYs~KVF7r2qBJ#)P>5EzMqC@_SRe*5wGKS%$ED(a?fZ_Qn@SeCE606i6uydgP+< zY!q>C<`2zZJ-1M@9)r2h<=gM(SMJJZ zuX(dj2pGdpioX4(xOB2y9EX-@0~2;euq~N-XUxUf(8WlP;Ou zAml5(PpQc?jhsZ*>KSXTg{A>El{0_JpwC$&>9PbRzeIv3=?Eh>z}}K|wFpc) zaSqhZk?JMB@3qOY3Ecv^hsotaQ%0fD)1=yDEA}w)d1~d|Go6cq^Ct24d}IE|C2B#X zS*4Q8`YmyO5D=%GE6*}t(&jl!_nPxUxpzmGInR@(ZngLo*_4Clq&aOjIG=O%w!YUv zFAk3ysLJYO($O-E+o7rSvqOjite@*EG<0txeU)AWcc9BHLR6-Hj@Xae78`7rDaPl* z8TnkK2%SqD;a3qy_%A2t1rAnICi&f3qfnZaLc4|TAybNk_)DfKDg&SYov;6Q^F`Kr z^ZBycVX=f`35ddTVEkhGz7ZK`v=frMLpQ;qf#kx%Okcn}h{byS!~DrjW2m*Q=8wXr z!Cx{>QKRtbB^=J|H#dB5z269|X9$A}u*|PoZOtG4++BWJN(Ev6xib3@?jJ9k?ZHR8 zum_2?h+_|a-)~^k9=tDf1iggr3uY~G9|xQDA>2bfsGef_<$Gqet{upvqAz zsm_2Up-7(?(A!CJ;qA$KS=`OIoK_)X_L2}W0SgW4skt|6E={JNl|1y7u5@L`RD}F+ zIZ_x8P1aCWnh9Yl#xwGlOjFc!e0m8-GUF$~_m<-qMDPt>#ftcRMBZ@Fsk@OFBVqyV z7^9%R!A)ayE8jt8gS%7&jqRhR?6GEQY2Y^!8+bH z$CY(SJ;PaGM302Lnh2x`egTo@n9J_tI< z#sOJc%)7oqpu1)rUp#jOTk`OYR)qyP*J4!bm&aoqEehdifzW9h>pEU=1QS2e9pk`_ z`C5wHCbF|4ozYhpT9wc30C0xLsyHGmeP^SuG}@J+MPw`qEy|o{ zsE8QmNtaWBammc+=iwwF6pkH#AFMqu;XDRIqYa+59Dg>_Nck)~hK4#Yl=@r0BDgGJ zof{ed89Koz2*Nu5^UpzK_F~lJ-xRWAxO3+5EzBQLtUKoK`~D-F=I>FQV~zi1BI24R z;xbk}HxbDQc?UIxJZ8)UxX7+mA+8-1EE)0(Io@JPv!d{+*yg1AY@YhYnh->vpie`OdfYfCL`7cxV)3HP@rI z1rwo+!h^sLI>d~|gUqGWxUgdmY>wiKol)CfQd_ftbD_QDZveZ=;lmLc@4bxt#%Lhm zoX%=(8v$~&kBcz7%lZfAAzNTZsWksV>A$rid?+N<0V1!_pEhS~5#3MnXR($lS>tzO%FiY9H&kynMe{X`0h1 zn|xFF-Znq;Hyxj5tD711317Wlddp}UxGbv9xlTP_`Wt`K2OSbj2G3Cu^2UU?V3{zo ziVl>LV_H1jtjg~h9M|*t;d^alz#eK59qR}oeN$Gd`H@?kBC#jQ_B=#YAc?zl%j zARtZcDj8vkKJt%~`lKoqqtx(I5upcMCyU~vWX-d$5F-b4bEZec2whv4lkF?lReXdkHj)5mhyW--iNkt<)~HErqYA$$&MC?;>|HEq zofr$60MDoiv`gk2F=q+fG;>aUKNQMvvTH+CZ369Fw7wbaB0ZAJwPNg9PmMj-lXXym zBUe`Y5s=IUsilKu&O)J;l4T5YTNA+VCpAr4^Nln6(7dg4##WoZox$FHbdqtt*%%sb z_0N$0L>J7eF1G!wpSd<*4f+0swk-IHK#6ZhwPic11j$A~vlHf%ZtU@oken7m5f6&DgH)BEL9^Q|C_anNJps(;$>6r^@3qC&-hiq!yoME& zp^{*Q1moL6r9}Nf1>(FXzk| z3Y)NHmTM*Id$GY-+x?|5G}fWHCYu{t)%LDhS+6w{e?j7`_O2xsu0mFoq|+_lYDsMn z1I%4f?hd!tc3Cl+U@Lr8RPI2xAmMfD>DpMBI!Vl+@sZt#R~1lP)as{l0?yIl0Iakl zcww?kU3ojRp`BTcy|}AfCpd}hg3j47n1d(Cu5>O!_M|=}`z}3ks3qG=&s+69oGBDKA||ScQZ_ugB?7(?l{{N1G1NwfrNgns$}4t>r#-auN3=7*&hwlxu@Yx~ zIE%v4X_qnZuJyf>*716;w<@e#HqDY(& z8*`@+K!;xGR|ud)=@Wlkh@ObD7G6F)lt4(_c_WpRP}>!o&$EyF%T^x0y3^2olAEn& z&0Ot>t%7Hzi~JSu{79WhH;LD`r(C+fG|ZT?(-7ihiIJ_cUA;8NoWL)ZSG0TfG@h{u z;59)2?w-NDYy7&kUT_Et{Xvg}B55h%axWg?F5c(c$X8Cgx<@GCREa|XOng=qcd}jI zq@FL|dE7@R_qfpL<#V1uDx8|HeAav6Z=$SKm!EAW=j)$6dflD?ee`|p{SZJ}#bW!Y7y0sqzLD13#R-TaG-8=k{Ue2c~edRIK zJ_3t;f!y$~m(s!FkK~&O4zho*zW*+d)U+Q5F_Cx-J!I#^HiIX)6quE6DuuNDm>=fx zf`$4~U6p*jB8x#T2#4H8<%Z)@MNp5;_#%hE#{!X#1+8Iu^9VlS%?~i=XCR*8{L}~< zyVrV1E!ITI=<}>3hq+GX6I{F4UnJhIR36?jegyLkSa$G+XV`L;1h{vr2q3^+J0L(d zUu*6R^ljYQik=A9QWP*QJ*vD+xP-{-?guC?KpIP5p%S}F`yFej`)DgSh!B18I;mZkZHcVQoz z&h06~v1^xKefU$F(`N?fO+~`_Wx)+5?N)V#Q6SnBl%X)<`?tpHYj$~C@a;an#dv2; z5@)!qmsL?|PGtt?Q9hB$+d9xDRz9Y`^@(4GG=b-u=69@H+1d*Gfxls|xt7kGWbAHl z(%sK?U0*}_RBN~LWaV_2vrpA0w6K$|&d!cR`6LAj|9YUR7k#d3<482fXJUjbM5&a< zXx>sLh_aG2RSrLz4I|FemW=apn%vr~~2!I~;g#U#%wRYBVO6#Rz8a zdbUIuHa8#?DDGH+-XU<}Q-zR3VErhJHHH zeR+jAYp7753j^*!t?ANGHo&n3pv=T6JVPll<0=DKHqD}HY0%rt$*_poPpkcRaAw0p z^sU$pSHPeKwI-fd;d)XD3weZTe&U|R?az89=0Hz!>NSscH-ASQB2wSwZ2k=MKh8qf z0UXGpAvQyvXG*G`7_&Mg4wn*{%@om+WxuJ!5ZWjknqgOJXF6@>8!pY1Z%)cz&!0}AgqDh6Rbg4c8sR#t-J80^d)q} 
zK!<@P0y1-zB+zMKgMi{(O8M5gvc-xFTloZjr2a0T)H64*1?D|%qynGwIA%=TAWa0= zHiLF;#&Pu+i2sx>1s~)nEyK+f{+*Q&EReYaqI%ZK`u&=C&`#n!)N;*kiY0_TOpcG# z&wzH9TTu7erGZ=Yk9{%>loPaR$SGX-UXG>wkLLFiaXf!Wu+5s#^eA#U&;N34JN((W z2L9}h?PSUVA-T0ySF?n644tEQ4(^?iNik1nekYtqR>k0^^|T1w<0%?9zq{e)cm6dz zHHY#s*xnb$fW2L|eI)GpU&Rx`gu~i5FPE7Fekek&XD`Jr45Lt%f+4;iW0pp*GXaIK z7f(HwgeE-!@U+##J0E$$i&l1F5Nn3zpw<~tUuPWqbTQXn}(O)`~8}{++JrHy1lpKymI9 z+k^Qp4&;A?;x$XP``kD?`1?2a-U@=fSG(!vz}8-S}~LN5|cQQ@h1iT{~ErYh|dx$S@8dgJU#zGX>_YJI;k5f@aL`* zf+TqEswcS-&PNw_yn$g-oOkwq>20zb1O2M?b#6zeb?s4H)Q0ST7uD7F)frDKu<1~n z-t#evNw%t$UC7JtFo*$w*uSx}6b!x+knd20A+> zG3V}G=eI|PLhIErdk>wfl8?l-w=)LU^PkzO(yd1x7fKEKlih5H3lKcufQ82ak>}`-_=!9VZ z%dfa3H-ZI*-_d*km*fU0UaiYtfeZd-_MM+o?{0Y8Nbjx2tw_d`il`QyYkVckh<^>e zQs$2c%*kP2!!vCt+Q!`+`d3dOJ|YGBAJ5#}*0cy-9O{p`BJZ0WDd+IZ-Fj`bwJ13s z_g3i&n^K!~`)CoaHEwJ!>9P89Iu_|?`xP0`qFMPY)jhn!= zL2DF0;%2N8%RQ4TKvm?Z#1)>&6L4mG!h}j4WYkP7b-7?y7vO#b#niWKyL9?2k#V%t zGeJ>hO9Z3IpXyhI=}6clzf?cBp32VXyQhRk+C<_GUO0%-4QRw9 zotb^dMz-=-%%K=5E6VKk>=xptuw0r7we-os0s*32S~&vR9qbXH$fZ>vpv%Dn0qR`Z z6$J1d0}y}}(5@oD-*FBCLHRZN}X$D(eIP+?0%5DL~Y zWeW)7QRZxtv;K`4*MA_|oGO^TkP)yq5qT56H?t@9p{N^}%ozfVu#|oLvp;na(O}hV ztd;POp>AZ-5l}Omj{xaLCKCa5vqcD?Zer>oplP-Q0n$xO{Y4=P!fl7TnaK$jz%o*6 z8}0t{pNYbvKETUu3om+-faE`kGh-aLLe{=yi$2P0gjdn8M<$Aac}4p{S!oBZF51S1 z4Py{xRKFgdD1kCsz}}o4MXWE=!c`>TV9rHAUbsMtYB-r&AfP&2B{koofHBdM$X%S1 zra%uWnQ1VO0Itl*Lx41y$w5HXoN5G6cQB0+P&21hF_D?Jt8cZF)J^9VlQs?f!!)eo zamr~Yd39JG6WqzLDZ*1q^YECn<^+f1*j772II#l|)CY>LfP2sbm_q5&6zR4Kz9K@= zAmXm{Otea@;en`L)lalyU%=NgSM{?hO`HZxyqf(SB|8SX@VHzCeVlmsP>VQjOF66+ z(}xkVEFMl=L~$^&d^`7o0fSgrD~1o|ZJEjEO{<2zF$R4k(2c7^)AX^v?g?x{es_fm zod*;0t2)S}{YpG6^;}YZ*Fw&Pfzl8zpG>^)>93`X6U3OQf0BwUiNV}#&Orbrni__H zBj%|*w@`fPUf#^%f~vlJ@c3!v-Ek7Q3GWGaP|OZNXuYOJ zIJOd*;)}OTrl5=yjpl4M(Jg+QiUj^P=c;EgaE4il>S3gKpXw)dR`?!tTN+c4pP z1Ui+1@2gGH_$fWpt7v*KW)ts+O|ebzXYHi5LP$XVVA4AL*CX?_SbM<^8`ETneh~~Z zgE)NSG3^ShPi(<2X&hV^ zU;Xfx>k?9`*BEk#P$0GBlQ>up#T4Q0NN@4=(1EA09;%9iPz1qWBQeNndVvY9&$wqi zOHdDMBXykNZk;QW1@2GJkrhMALuzXw6#Mi9&uX~6|0FiyYkczNKi=xNU_Hifd<_$B;<8BR+$ zvFpAk1o~Haa`HJUF3{V(6oSV)Fb!)I-AKyR3A#BePDn1|-m&d6>9Z8Di1SKN)R2p~ zXZkh7B95}*ao9k75x1t=;4xanRl*`JYOB2pE#mm+WlzXOT%`lIfmq%sdM)C3-!0-O zV3+4R3`L7LVXsBp!|xVxQH~qfq^bA}jkJZ=J6Iq9mSm(Yyura90sI^qX$ybh;DG>1 z4vn;hzjg?C*J}~??z=@?W%l$zvYoJqdyE!w>g&m<|-Jzv%MmG*DBPci@&gXPX0C!2khWaSAILBASUT5&xQxR?b1JNrwF|D(bFe(L z;M&flBcNao9|6+sOeO*@&JiJin#|M-H7z$MTtcu2K{;CIjW)NBA)X*VLPY}Z<~#(5 zA_Ntv##nO?1V|!O73OCApsDK3ToW*N=O=b!^dDv^3QRD#g;h8eu zyd=u&V#70{F>7=NFzis8s8~>G2nGz%KApL%4m*S{5^otjDxk@Mt0UTPQ6EZD?wU=7 z=nwE{U)YA&$=K zbBtJw3o@C0(ZpaT&Ep|JmC5u+K-xSZ0tBa+lMs+TPke0B#3?nYjG(Y{#9C6C#oQ$f z#Fn@bH&h+18d8(G7};8aRgkqz8&*;ik}jE|k5*@u^LRr-&isnHRbXC|BvhIGfuvQu zY*HL!34=F{doa$e$BpVf4aHWT0$;3)bb%9xd07WQ2d9qm}d|tHg z-fMG`^%{s{Vf38C#C|>@?R>w#sV#S)Dp3HjAhZTn{{F}8l4by`l|Aq&V z@YnW6{I5y)TF%zWG!+(TzYa15Hid${nqP$nHNVa;u(3`^#g}ofR03enXfNDrV+{*G zIPpnE{RSrrW(?BsJuF&u1?LPImGZn_f7fqGu-a43^v5Xz`k2w~hro)kAKb#fT>s60 z^ROWEC}a9#l%qwM*0U+8l$q3r!gOJai$rWmQ7DAKE7uq%QMfhuqHt-EgpxpkMWKY- zUI&Z9Ry<-_E?f@nBo#5I$*d;{IDV9L&C>8lojF%ZNvMTl%2){QmgK^skgvuUg)L8L zgDH#!p|CQ`8Oeu;c!P3&A+x{x6xfd-W736%c(ly@7}ZW}>kO`(OYwn_ZHb`yEPlhD z_zKb*zdaG`;Ze9z_#WQq^Ad>uzJ_ntMZhib&N6RM7+XcAjHHn~nuyCTo+n}`G4Aq0 zCA5RpMN8r4`2~N{hERZHJ)1d+!l9j=WtdE{VECOex1d7A_dIRI0&T^g3Y8R@0J|0q z*V1eO;(lcKPZ&m_O9^iI3w@L`ve0%U6y!M`6fvo|sUqbHSeV?ve=NyCS zyF$X=p2VDz6JMM0eC+!k3~zZKAAb$JT||%EH!m_Go z>pZeu%W&tI4Jk`4Qbtx~gSShXEbqd8_vn2~U?bv_{j|#)wqM`;6p9zezd12Y?tB;= zRjEJR3J<-z=5_}Z$BG<`4*S0PnhC{=*E(nTKKq*Y8kF173mcld9u|y%^zGCIkH&G@ zoC3eVHi!64e9U&|rN85js>j}KYkco>*aHt0?lL^rQ!4Mg0pU8-<(D@*$}aZc&Sau} 
z?|eEkpO)g$Jquk_O#?QWqrjH zxf4$Y-`3LYg0UfPS-M$XO1(Pkxh84Z%P)isV@)Y%TL zcRJjCP}p3HJH=-DPRZ(ildW-zb)PO6~IOZVcV+ zq;rZ5PuWR0#R}%gky9-9DCrb?T6e?X6f0zG3Bp%SdP68ocIJ49)oiO?hFa9%Lg@ifw&-XfkrclTmHh~w32jHiQo$Hf$&5MfEMpChW_`KJLD zMgBI2V0pw2h-x~Mj)2w>J^}=ROeO-_LqrHr1v2#zAekvafM5<&9|1`-6_1I~P1PLc zt?{N!282al;)DA1jpME??-g{?$(7EX#CA|@89k59*DasdV7>+1kUjywB{=&3k$@EO z6?RJW?6BGX{2B{}Yn{ljLht#;G z*g9$Kx-59F*F^JwC=9eZO?CNmQ|wvRm5|Q!8hb2me;o`sOi~ow()su?ZJ3VL#MBI? zMas63_KAE-Kn!kk-3&e`k}{l_rTGy)Z+6Xt$j^-`D*-==V<1zV{q=j z`zK?ze&1kP^Cca7**kG30U6<&fIhw$zTvGq8vx&G)!VWo**UMqcYuqh$B6YxuAT0vM>y^(C+FD{j4nS&Z`F+~bG&3-UKX@yat(D_#b6jDHQq z;|^_azWa6jX>fj3CGYX@Dszv;J$3nyJ{wnf-}?O$%mj5_Hnkbko(%m4H*>1Ap}q9( z@DzM+zem>6^G?0#9862f{qLV`!`}DcE;s2XN>hEJzvgs8@siCCa?4)6yNUbI#1@rZ zc=o0&6{e-?2T}8mzG*MTV^q7|_?9|ecs;(?lVN|S4vf}4$fR?lD$#>HhkIM-9%NS! zOsP6r4BRh)c<$}cdO_*jol&_(yl&`fw^9AMx})y(q=yc>WEeS2IX562H+NcD&$f~k z-Ei;74FaCb`)u#PM;Xg&;khfiho0)l&Aq=Ex?1C*zjHTj041 zE*x$-C0oVL26dqIOYGcp5U*JX#gjazUb7!@%^qil2s6@zi0s4X_hhZJ_%*oqJe|z6 zgR}UNUG5Q1gTWV?lIBOHX@1Ik^r{6^+3UfBO$BhOdl1MYF*!yyqIq&@o8E8~;IEqV zj|wPJ+=-fSoW(t@NU;9%y2 zfV?1T8>-=C=7xakASnqrn|ULkHHi8N)!>==BY+z$B>`8n9}obO*k@G3%`5}~KrKN6 zV?G-!mKbYCA7h;~tMOtJOFJNx$OA;X=JPURAdP1_lQCm5Hq%f_j8eO)6=TYc3YeC) zd*Ovc(F7P~1|NLsj8AzcA>@Al;zC2IGf{chOstuh6{T~Tas(Wh$wvTn9&^k8#on8T zHFbUe!kH720AUJY5DudeL4$(ggb9P90tQ7zoB}Gv0gP6yTEz^2h!ZNcSgi&`i&F&! z#5n;(id3UgrAjq`RjSlt9g0KoT|3D+(BHSed++nS?|q-=KF{S3b3ALGX`g-eT6?X} z`mE5VGR~AR;9mfO^l+Q7z!o{%TM0_!nnV7Cg3S0RV%^*|U8uCh@0L8rPB~UM+m6$wW2V*c#<^yvp)PBwF$FeeZR~G-i@`Hkw$Va zxdtN1rU3GsEs&3U3-WPCQ&zs~laG7bJKm*};OEG>h4XXp#sNP^R3Cm0jg5()!{B`* z!js_VAiDsD1Mzd*0e+6CrMw@FQDHw}b$l<*qtoHFM5+|lRz%cT4d+Qj*QGPzu^?zH z{7Ukp?QD1|Sy|>zNMsVioOpK;vN0Z1fsJuFtqs{2Z!8rA8zLQx8S5#!X-&I*B^t$t z-#K$b=j{7JXSnlrz}Ek+)t%qH>Xw3s0(lc6bN;(lyLTj%Np`*ipF}14ndfpA{Va2j zL=FkNd+p$mfF-dp){O2G3=Rca{O`XHR>+@54h-zGUJi^KXaT>l5`xaMM!q+l2b5c` zV=KYkphfUoeKO*}n2s>HVYy%rk*@M*JNy@a0S|aG#H{8eh)8Z7@->8@EUkijK=lPd z?fNBf{&N*@enTZ6Bg^bFMaHDl5}@(o!tELVk}wjf$GDDFF!x{d8m37Qle%R%~I71lmSm3AE-|GW6L|l@$}z208n-euEM=RZsUuK@>u=w{WZ(;pu8go2Vi5xDVr>xqVs zH$cgDjX07NQYkw17LuIU)lbszmMq!>I7$5>|95`|xP4j$%?)iQ_YJZmSt|y8`ZoHJcB;nsw(0%2ZuB=MPcinbAL^ zOx2U^WdY}lV|IZply`pM_A4LOw{^hTy9U?y{9*N?KeB}`zUw*3`m^T|vV|6%|Il%0 z^&jVDX{$l5Sx+6svGwM2lr$3$`Dzc5x@*i?Gt+)l#xwgvC+f z5f(Y*411Uoq2)i?6G~b#)~KX$jl9zSBCcb}QJ2Cr@d2c*2*?xZk%XW)UPcI%6Y1j#!HIYUAs|npPap*4 z@ftz^tXh)@!NqtjAs|Q7V+cWQynzq^wZ&vY&=^m;ir+4{ZsG_*Q#?iplri+FgrFr} zObE!a^l60PalDKWC}XcuoJ!0|VL7_RbhX$UP}fxC6N>sJ48t{f z@>Pap%3oluDiNr-Kr+Yzw^bxF>Eq(wlbCT5R|2P)NWY}z;MRF3VZN07)Wx}s4=qxx zl3*??oo3+R?j*@9i_3)hQuWnf+9lk4>6)(8!&yzXo!o&-2zWEdw+O)s2YW&wnxX6! ztaNZC1kxGg+r$y89lQyFe1>v`Eq;wP4*tS8{a~B`L-GmSUnY0l$DvI;s~KT1yiFiC zuLxR-lShfh^+rQ4_mA`p^ybH zQiKS{f+{6EsQWvpiV9zqV7%1Jqt>51WRe0Tsw7w7@{-pKQqW%_4y!!e^bgG91~`KfM+EE*WR%SX6y_fDSmg_$(# zsR5M`h=^w)1H;KhF!gt@%V}2`b@3l^=283V{}SZuIPo*#>&SOUA%Go5U&jjM>rie5 z-Ht&LP8jK8Gx$0(`N$cvHF*YP7SUC&&BvzNcO+7UGG&X{gE`^gvk)~R*-uwX1p#1*Ad8;GqHW*Vgt|do5UnDf9W;8vSd*#n> zO{p0J=iRAlG^J*cx1RUeYfjB5z4DYu%^HN>W04Q9=_)KiXC^?e1fhtn`s^2^u}S%`ws0*U5CE; z=n-;rOz5kM9nd#dPW?H0&X&t*;F40l-(CFd%H@f>!LE|Fdq>r+ebl}j&R%;V^ial# zO{I_E?7wE5{*Hay;{-~gd2n{&57lX@_rR{go(P=|6$=g-pm}C+iO~syA|-)F9?=SC zl+nXB`9s=1nJEEQtg3hDXYz~mO-_Nd@sdHFVKyj3rQbErjC+M6kH3S!P}_hHS581s z4veLi<<2`Y8b2&VIoTD~-*}{*D!PgC%|`c}i}~PtmHe7+LxLTn!p%s^WtOI@dH5L-V}!Z*R|M)tyAmKMi^? 
z^4q6FFG6z(Ys=U#pFRHj2{b1icy8nfy#URd`=r~HwCt$7*z@WeHLW*Ix?enEJpj8} z)#DDE-Ds0S^@L3g7DlJIJ$2V#la5bfZJOP0MI`g;Ea_u9(1+x_L$AL7nnTOC?@rl* z7)`f-FFabg5n#1}xIuG#p?*PTF5W zc)MoXL`vxsp{y`!ofQsxDWaA4wJmtiMD!Z}8YS&lTY{*g!y%3~?P5fuEe7XY9)kQJ zuG^OnMkh{$6MIjTQe)suz+c6(P~l24XHPVgFg6=|qC_Q+a48f!a0!7lk!(WL* zqi0Iv0j#glwn@u3`Tuh~{lB6IP2vHgjmdNv{u1dv=#;BCzM?6b-$2ZZIlCq6qL3L>AqM8oR`J=a-v{ci82V! z{K6VlBP|Ks*pq?0*!i^vv6#ngQe5uM{3DQ?jq>BE?$F|=!7%<)&OWb>WN>O!CduMnkg46pJr zr!2Q<_Z02g{h+EhQ=`L&sApRr7|7a+*Z8|Lu})*~A_xK8W;GBQa2hmTPWp_F*#uo(quc{@l2X{;PskOyg;U8bhmJ2csZo@&%y z&eUXh6X%)cm5o1_;wP@B&{N@uZrS;^cP_l0;LU6nz6*bwvs`*=#^ROrRue^t>nl7X z#0$(9F?elkscZ2uOH~SKySL#CZ^?6}A%-cOMrk z%uE~VDo;XmwqcN=*<5yvS`oez>Z8X(620UbbBm6~^a3A^sb@4pMZA!ryRyIZUphu_PMK?O%zv-1}1PTJE-RU5>IlRmT1%Aj6o zXw5z!mKt~w(gb8E!kCg&Rk$E3yI`R}onTIv6m#QmF35I?^<||vK>-KVzTJpeA#0!A z_I~NIo#KLpBK2_A_E;Dpx=0n&dxk;}`1hf?uxHWUJp;a;;8frvMi&X8x(_w|iiWDM zqcdoBLY?F&YA{^P^yR!x$8tR9GEyv>;PyOjDo;#C?4xpW$A_=d_wVNUGsj zX~unqe=h33IP!5vu{rohg7svsv@mNAZpI*DNIaQ5N3{}v4Vo+BsIT$&^PZDU2uiOg zdIcXixr9J_McXU*$k~3^$uwyU4Y}ZrpSAhQDvPFt_{aK*BAJ7a$M=H;gh2a=5=IDm z1`7!RiKfA9@*Tpuk)@3dZ8a zk@5~0_>yNNZinXl!CS4_vQ;oew_0_Efm39TmUYSYi0$!rzkXW=zPavXhTo3$PkRB(ik2MQwiBOFaqK%(RN+S_ z+%!^f)tXiBzoQ)7i7VNmT-rg4V*;Pxy%rI2rK*#Wl!A{!YRkafDbfj0{_?Gb`nT-U zFZarZv#$Oq)r;{r26f4|!oH{*efD*T1iN70^}hRFhy(?&FZO1icCST(1F-K_-+k{z zf@0Vg-`uC2I8bm5pE59ewd?_#yB^(>4I93(&<r$(fTj4d3?ZXJuvi`~t&qFg+ zQgl|SxCgxkn9f$}sUGInA?Ah+5g)?eWgb^G#EY<}7X-3sL!HfxSOr&Mp^0~UB9}r~ zVnX4B@1yi8#%uh{{CybsYMBTIipe0F&81YY$NPRoAd5{Lq`!)Zh=XK~YD4cq zbGejS;-GbbEL-BBuB(`oILP3rmIV@5l5;6_>-|#k=|b$PRuuVZ@c$}Y)wx;Ial@MQ zJMmPXa|bytE;Jj`R5z}Qk3YQl8kSnP)cjOfKWbqXanLGmqEJ?df6=SB6lE6vjLlG3 zm^el1H4IBg#kEU6wWjx=FSrynagZU9Wknp6a}5&{2Z`O(SQ2riwOq=t`S_JO0$Kd{ zAtl=ZZYD(m66qWUAWX#LAxlBTF6K4mjh{j82U!6F63RwlC8X>ls?hV!A3OYj`%H_b z_)f*^o>ZZ4-b~dL7%Am(yr?8iTi!UbM;1OIG`SqsWQ&ijCulK5KZu0~HIyNm26ag8 zD0k4%@L$mJSyPG0W%e?Ba!DOi6#c_j-&2V6uhbo|YK%>SF&s_y%Np_>82?-ntOd9? zme{MdptG$(XA@_|vl^Th?7o&7{mQpEhvKkJy2U`EX0d`7SW$M@%@K zU7Q0%4yrZm6sPG*2i1dvV1%S*%~VZCGto5pO@%{>#ca6n&M6YKdj{;DU1v#~Q2?fj zGPiDhy?9O5Ov+4QMqB>^pFG%=t)P4Gk-?)WDm*CSE$~pw!ThNeA8gUz!)x^qLf&3Xs%>vKQ$XN9I+uu#pLRjyN^T<4sVAIdklm~irs z7(}Vad8FH~y(q6fim|`1KQ^M|cO?8qRt4GG#EeWk!erf!iU|6xQt8ayHTl6&9rmu` z9BeL|)xR2hcec=zyKBe3fv|b*7Hpo{HwZQ_c=&O5-P~6Uo0k^4VNqT-x60d5^@t;D zjkkU0d-ury_UgYt!N0frWD7!zK9v9MVm!F^)r{i5$KH;PXd+#oJp1M;UR=08a;0ZE zb9Ya6Ii?f6tGEQ4%O0G7&9hD}+}%57`BjVFox9}8-F<3S0Bm0P9X4<7)4=8>vgwwZ z?hH%Q@b0yv^nvo_p2n()N*7`6n-I8zjI#+>)W7Lk0u3OK-}^HZ1L0v@L>dJ9EK1Z93Xh z|D#M<<3ABkaDW1NJc=U7SlzG4$v;XV!!oNw!iJhl^67V*i}bA-k|3ffVi=rU35)a% z7v)WK&No&|bT?K@d?hefOYDbNOMDD_GRdr*WLGuBN;2n{8c3NkVGdN*6GKt9tmW9| z)1qzE&ZJUpZy~(tB2=j`Po<+k#KevZa0#(W2jPzw%K-B50`fRWU9+StD(~?_BJyH? 
z4+U192Vj{b=_UFEG7W&Bn|?!NXRlF0{}%{t@fKx#d~sI6qA7ldowC)?53CFHOde?r zI_J0O4E1n`>=BCBv^-EzSPb1mK4#1TDS4n`nyb+{r^MY(gHv@NaEAu6Lde)wloe8& zhS04*6+rT+Lcoxm`;>Z9VMH=h zS)6ZBc{q`kcif%_iF#uVIMmh6cU8WqP8_PKces)tglfqpxkeTFmse;kAX;QB{I~Zc zJq8t(L*)^-EDcD7N`j^ej{vnLEk`QNGhDN$oV<=JO0_s)UA@G3&@JuQD#(yWVj@#=R9kpsAWwA4Fp>Q}1{ zR4%`prPMhxl0}RKc_>~c%mUO*PIV@9Bl@I_akKtZ6gOL9jGJY^zG4&uLw{YL$mQ`= zfD)q}TITD(rZ_>hZYp!~Ey+v1{aWPiPY#vVut@#v*?mG!eYlT}gl85HQ}?$Q|6KGA z2xSiqoOtYa_alffe{b|6O_eaO>MoBOS{&;u> zayJmW?fz#0r;UW>(7enPd_lNW98=o{C>0!bg7Qu}-dH~R_~VQ8%McDY620ko%==%{ z5Fn?q{aozX^`_c!cdLFuXGf43gGNZ8Ud!>&4hKazFf78xp5|n=DrEBVt|r~oVksq3 z%E*$kulC6es{h1+=nf|PpOB+LH-2JX9hmT8qm&A<14Fx4=i8+@-z|mv&8tdybU5aG zf0X0%$AR5l$Iq;|iRj%==2pgDSbOwq2t|pHm-l;MuHRO^He&~(bbbNNi+9{GQaYpF zFj4)Tto#QE2hg|_$_<2B=}Ru(g;&}OgqiAOoQHAbt%v1>p;5L3a zMu0yKMoFcetTsg{HVXh1VRGPWh_W9aw@F8}vS%m>M8J`70MwvX7{l9kwnpseD*myW z$N}+nXjDzn`4fPNi56!7^I!_WH^&k1hZ6e$+8G@FP&!2^A}|a74qQT@pF$2K1cMyx z2|?EsWv@Ww;0gkn5!PcEKLHQ#V^hm{jX3N?7fr7wKqUHj5h0Mo&}#`nbG(!g=wj$~ zBk-t=A)X|`yXHL6HLtfzs{`;naH!2DS-7^a#*#cH6PR$Sfw{d=9N8YQ9pXP2M2^)c5PWOQYP)Wc( z2rcVSlu8|lV>{9uwBcjT#|TdBapewgl<*Lj5Z*`nNh?XVXJ~rBWscrDbdmy31Z_g) z`b-!sy|thH(yg_lf^`rq5M8m6aSh8l@b@U)QtJBgY~jLBIMPwdR#)u{Ld3Oi_`M@^*D@Frce!lN%H{w^5VuIfc1U8{VluXF*-Zk={8*9g;N>Gwj)P3_;iAV zE)YbsY=D&UaP?_ruKHHAk4w5>dJTAhLPV@=mz*+0`DMjb`9)K=ur%9G>)8#-xW`Cr z0irTn$UWPgtkh45nq+8Nti)d9ASAB z(v(ee#TWeiGy*4=m#owh3|)SF0w<@-HR9wPXi^FB?gy`SR1*Al+61hbz*B9ds|i7U zf|wBKw$if*fj&V-2qfF+o94vlNBk0rEp(fN57x&CVv9GlDdT|H%QV_X^GApMI<5}? z7~Ab%Uo}S{$s>Nxo|n(V$GxP0zER?^o_H)#iX@fgFgtDnF&(C)Y#k>ZM><2BFUM|| zI(k};GKYRODQhv}8}BS;qR@7Do-@qMz;`G2yXa{$dJch5vy=T!XUOehz*~?o!kijy_rviuL)0=jJ{68Z3#MgFRy@lg07zY;K@=^_v^G0;ODxtXC0#IItH5$ z>R_`c?Fob43Izh+qMwYkc7|&5zH+7E-t=Ws5n8#dlR~xkJHw1W_lDm_ps6kMyP*K# zxOBNOMLg}Uslwm^hH-6n&v%h=Z5EpLD}#9&mTh#n2Uko5n@ckH-I2rirAX0PymFytHz7Uc?Kw zAl=LhMp|L!1?xdpd=IkXk;zZS9wRHh3t92HOUQ~(Jxk9qCTg#&GzfDfEX4=}@7fn2 zWO$J7n&Dm0ET@c4!)SiyHa4>q0!TL}m(DX)0@ou8-EdDu<*5p>mL-y6?v9=IwMvrG z5Y%{k7H1d+c6!KG+RXgywfkASeEkpkJ*S+cJE8jIuH#hx6Kh7t(k9qVn|W_uKs~Cp z-X^kL5a;}4xo6ZjB@O34E|xH8VD)d>iEX(p+F>WT9Wm|hPct~Ljis}#yd?a?NpAgB z#hZ_v;Cx99sW0F5YNV*U7UR~-xDDV#Az9uQTi#Vz-q%|8EuG!9Sofc&&04ssLPkYP znVP7)$VCO$He-v5{GL}2IiFYx)v9?{lb~94P-A;UUQpautCD`69nk!rH(VP1J%6bf zm9Q3AoEBQX3bFByiC;J;fj@)uL`38ea3x&_Ja zVpFI&mc_#gS`_QlY^xI?Ri)hT`76=clCVv?L|g8Nci0=xb|fwRUvKqv<-gvl)2x5J z)uku@^;VBe|BttJ4AwD|j?+KakY{Nvelgh0vtcj=g86wlWkNx9NZ#WUd?6HOXDqkRnF$rY6})WE6MR7} zj}(TAb<4O=totAJS6l!$V@I2BT%!yIY#)N`X^oIQA>?8ch#v&u$5#XhqM!H4Flm$b zi2Idb;(z>m&4Ht94_!d$C;vY_Jmx&ZtWjG0h*ufTA3kSY2Mb$!V97azCprhs^ZTHm z(*Hi zL_WRw!}69b!i%RohK>_G8Y^tOc%iPi9tuomT;6oC-?O3s!1Vv+dVSq*dqxCohhpS{ z8_QzOu3dB-%8eSnzO2In(p!xm}H;=y<_-xf^Xzm~V?9lTAuSTHC=$~>A&1*aS z_YRN1ab4f75FB_ud4~p?*RKD)@?*pO0OQ%+cb)wH$dfT}cIEaxL&rW#-h3X;uK%vp z^3%vI_g+Ku#sL?<>pw8}voS-AgW?kcZ;HPDe!`KeMMnZ*WtH<+PW$uUo(;Y^7;3o0 zo({dSBXA@FK5qVQb@v4y|02Yma7VN6Nm1^0v=DneqU6AW9Yw+xXnt{ELtDtxW6@|K z_NaY5`_AKM?a;h*#~0h*r@Nir0L@#L?)ds5_x$<}XkHd|=+qDAhb#pMp^{tVUsC!0 zAgV+ndeR;H_0myamqPQ7ft} zeZLWheud^+51-l-z%{p|fA_ypwbAyt>#=Uzi~gQ}i9{J#9>+iwv@uIKN|tJd5vJO|Bn zH~%{I%ZAOZFU{ePny^!&m-Mm!Y}z_`3@YCoR>8 z!Q8N7ZNs2-)k9HjofjYX1rFM@(gi(R&dsXJ)vqRAhvwMxKd;{{aXS7Qn(GJ6`jB>V z_sPGZd8PX|3-{kCI|0olR%fsLb~7-v9h#HQjQI3*TXb&d6+jOzYp8zn@!^L1(7bp= zZvVBF{!0KosJw7B$IrK$ZcInN7@6gSpf7sXK+CH=gm7^lz2G2P=>l#bRXrLJl=MiS{g`u!hiFT^H zh-K-CY1YBkZwVVL)PePiu-T{f(h?}9iFv);MA-DVye2$l0CM%i#^Cm0-ZT**fSiB` zLQpVGN(gi_>El6QT0Rv|j>0!fcSM2V{+9v_cdCv3gcf!4nC5UIzR@xsx2bauChbu` zBq5JHE)$?|0t~_@B1RMtKNU#E0t`m!-YDBs7^P}usND5Xbs|dT*G)!-dxoC(XEzXxEmpi;yN(gT?^tUu@ROd2*dr2 
zgZ*T}!Ug=HmH-Ww^k$(-7XPQ=e)5g2k}$HXrejsO*{p+_+k`+nT}%ihP;;9Q)K8ZY z0v*)co&<*br3maQpJ>rC{l6IQ=N(3e-HIR#_auA0Ijy5vDfl@erq^)SEu#;Ud^Lx7 zEZ$sl>^}|n#|t|mjD~yT?-4a_lY=^On(1}F_NFS8>~x!rfTz02;IL`1ot+T0blW(d zktV@Q7;S<{+I#xrBlvT|6;r=cx@xCmAE(ON2->gCsW$iu#=;AwqMt^dALTcf`|}Ea_E_a zASXde2y{90Y(kKiAg_w_NGEQ~m+zg9Uq_y-Y$NE;gnWNOke5tuCj@@{AVN@@tZe6J z(c;PA8Ug1+IAePkTH%=2x@7;tYw~c+@iTCT#Ym1`Ta1qu<<=h~F@o=6D4zv%dvfIL zSq#){y@Fsqj}QPgn+I{kaK3;L05w~$U<4l{1e*C0PvVGCd?6vw&e!z{#_&aiz%XCp zMI14fA3_L7DY{-k5?@RRuoQ_mam0Ln6d@3&=)4^rq(k~(X`jA!ht=I2nAfP&=s&r8 z5ScrRraO{}wQb2vOi#S(Z|GD)P&!jg2qgRHbV5)yQ$`4M`{+zUpq;581d{!93qnvo zQ$q-J`{^7)pr5HF1d?y*mV}^rrhyOuxwkbTFw7*~!f#h{fNnzwI%Z;oKzD#{M+ksJ zju1!=(jDf+%YN`D3-FowI%I(BmC`;(0!icUv#P8}oC@)dp zdysUoa!2n$N9kjG5BiR;=s& z37ij5UFO^n-PDHBWY zvx@KsHl+8e+l#aQgAY)&Y?ZD42YhC|M5la3M z?B4tjK0wV1Mv0LRuxn09Q0L}}1Rr44>|Q=V@=`M&VE#nWb^Ti1w-alop9elb133D^ zuF22P9>GMIrkc@d3mFOmaOAF#SUmEVpsgn!7BbN!44gUJwu!?83!p?YX%N^${Kep- z*}z3^jVmG9sD9@o#v0_HdBFf*&5<1BpxMjDDt^<;#}oiZO#5U@E>)c>@N%NdDF zMUu-|rOUZOmkTDIN}E5z2lp*?8pdjAn&?)B@v-;g737kiBgoAx&8mup!I6R-WE?AJu@ zS1NHaTw^ejFvff_V_}~A0jK?B=VGPv3Efl1G7@t+#{69C^|6)t3JN9}tUB_n-jmrM zUNX;MoReC1Nglh$hFvUWACa?XVM?rDC>;tR73Pp2)2{ikP0en9oYBI`vjvJnJ!G)}A)DkSg<+!rD98+GmRO zJB{@Rt@Ymq>(C*#fn>YEt+sw;wmwO={gt)@bhZIWp58CLhA8cZOYDZm*$vxiH@4Ld zoX{i7>}0mSG8=xfCx5<#zcA2uK{CI;%F}O`=Vx)g166$gUHp_4zNx!>7nkt^s`+WH z{LdwhD{LH>ksVhhIewvZTwCb4k?fpt%kfK{W08&XA&K*`oz91ooWD~#m%em9u5&)w z>b&^0@6y}8>3{mlZTn|bknAWjQ%E}L8*nyF`lclJ6u(7+C>dSrioEOFgShrFQzGX~ z&rVCcwqdS7A#uth_uHxz6qB8sbWR1(=MtAvolBL(RjYHY*Ub&ks$mq7Lbxe$w#Az@ z*LEJsJ%57xpnUs>iz1Rbkf$Dkse{Dop(6D#nL1cX9kqr!Jck-mKpjy+9jQM~U#9SG z@L)`;A?XBpN(0wAj*%i~q>>`R>Be7uSa)RM+-nYtOBr?77&0CsO)-Px=Rq<@;j`De z3U?j3kSMhvn_&;&DHHDu`uQi-M3%xFkxz1C<`y6Qo*y4ZGRNh?K_>N%Sy`d6#k&Ih zZc5KZ-yElZ#Uq(w^pk2vTiH+8Jv*QTQY=i7dhicwqP}TcEOB6I#RBFO?}qB28`Q|# zv>BZ5#pYY$A2B){kF7XPpMjZA#kphhHkSRQ3Ie|*iPA3KRks%Nox(SH`qJ7*{y|lX zcD|SP7590ZxR>YxPo_ad}K_Xz3!Tw`=5dCuWyOA9CG}!uxZnV zmTO9Pi_?pm7cNX9G3I5g+%H&@DrYPdGv><}OG(a!WsJoJ$BjD~>u)*AWzOHdWTcCo z58q-G1vL7KNW?BIbw4a)7c1CD<>-6) zUhjOY6uvk1$x?QuoKq&^lw+Im1|ooDr2 zYxP9y{#5Exxzp;I!U}4e{oZBOWw3fFcd3fAek-=_&a?hYYyClQ{dcST`#hKPt=6AN zwjN-zNsOOQnajmnE}zP54K{9VWKX}Hw!&7IE3LM_Cb>PmW$Pd4_SnXAu!pB8&@)it zIRri~rCTt3Wl44;ciN3EvkSdtH+HAnvp7!~%se*y-&@_1lRW!}@#n|!KTG1L7V;OB z@t5A>r;{C***FS49miA{tSL!=Sui0?*23!scW`O+gQ^=Jv>3Nm#sy$~qXCffmQ6Ct zW{qWw!P0?8Ewk>{SVtr&1*Bdm&MO#`ad)#8epR{r+rN|&@kWDKj zjPdlbSsSmK8ap&uhOB&kbg)idqlX1nIB)iyQpoay_c*l8Ng?Yw`5b`H=E z()Og04hv>TAktz=fggN_9?4K9Ov}_z{_yL-WTRp}+BYZ)#h3*o zLlQ?;2I3Y#u_s^n{r7h^TFl8z@?zfzF-3FCpYMqi{1ahG+{{v;#fqS2Vj-it{wQFbg5bvBF9{=&!r<4AHs1I>+rTz*7n_pf1${yY zQTu2J%A>I~Qrbb8cJv37cCgR**X?c;f07h#g>&GZkk{4mr2qja%>!p1MtkEwB-w<5 zPz;vPcTCqqC=;TUn9_*^)uU`ev~qCno#EC(NeNbykxW8L12W1w2lgywu|w3Cz{5Ty zl0Nl@kD{0ijbjhcM4iwo;spE${*o{f31PkPZ-)ol=Kk;^L6$a)C4utxSst#VS9L*o z`^RDg_#*!!o83gE6tIni5MbyJ?~6CLE?*zB- z5?}(A;!#Z7Q~nSX7*1= zu5ZO(GWq&e3H<=hi=>SPOqvnV)vu8ohi~e)inzmFljQJ%={DTq_Q_}~(r&*X2!IVX z_Gh?tHm**vtw!7RHcc*7KwKyKf&s`m4GpG&BHAX{o>U54v#{dLD^Q!`|5`kvQ0v^E|V$}GDGa>Jy27`xt=x85U>#u`0tOK{FC9$DD@C$cu3&3ZC zOe=iuB#oS4c~s-T8kHym${h|i%#_xv$L!!wzDhmgSD%0j^FK&(HwZW(k zoKt~B)FM^XK-cWoxzLza@Rp^ttNv8yVyI@JhZ$v@v~fs@#F_S!RVSch(&fLRofE;d z1Iwue7TY+oTwN@g3rY}81J0xl$jj8ch zyJiBxmuXTljEbW_QQ&uM@S3Ku3TY9&MpY9V$^B>x53BJm%$qDFkCOEx$7`*i_5l+v zGlaPsgI{JKGauk{-iHzFYL$#BS@jabPuOV}FSbsIAvx%w-d!{F<)J3>0KE=IOA}e3 z*9~BQsjl9muude0F138Gap06*I2680b(fJy;~xCF3B`dvV@-6nNT*Q zW=t3)zr~B}W<NRIwGdwKA$4dZS!+*>_Y6GbWI)wp+8 z9(EM&BKg;LF?-U!LD#l2+`;V(q&JXmLACmC!*YTLtl2TnGKb(8!a} 
z58~fR?N&ku>3elhlTim5HIGpT%{S?wmOeUYb{`#N?|KpGAf$J29rSfy9i-{2gFXWt zbi$;ADv%B;_)i^_gmh5)|4auhL^@~%u7mRaO$X%>I*9z>UvR}sT<|~`_2&Sc_w}!W1kc4$zuarnB1ET07$S`9E zHV#=cPrAt_xR_!_dQLDRJtxkXII!NI4F7hSJlD>G(wd>t?;MZ-@$>C^4K+&`k{?OK)VoK@xpGJ*gC{~$E2}S=x-H3gtP=Ec6GlNiS+K6db0*6 z|4nn~px(6+)8B*j4vOcNaJ7uNN(Ap_wfHB<@0{X-s!%c4^5Xxp#_b60SP9)z`4dg5 z1lA;p8}udHEw8#CDg@=G(DmD2R61O^)WWyWp`iAz!ZwjS@x++%EG|1tG)aS0)h^~# zcW67#Eqa4A%*GJ{=ca1yN)FwjDH9e7DM*`ssbxwW*p0SbI&qd~C;1T5fhMzy7?^OM zCY4kz?4qG(T_<7>ieUP%KpksUqU&<@VYK=zgI7~P@q-bNxd;tizx`p-I)4b=_`%Sv z9uRWGkdH6Pbgn1I(bpB?<8M_^KCjE%6%scGeCV#XXPu!Pg%{IN>&hyDE`&R>OsG}n zvCq(s!|swgC)Tk%U_d63#()kuk_LKl<1DDN@Ox<5&qyy;E_bxoDWOb~Bjiv%0K6qB zBdwD@o*H-K@ttw=+shrL!xnf|c`- zlDzW-Et!rwvVaASuG;tjlw|u`P?G1IoLQICe_=$yStCA_FG4D_o6d+B0arYtH2k$X z73;om*qNIV<8oGKxL2{wg`1J~;5=+r9fi#e`@XZ~ZlRp;fj=eU;LxbGOc zFRGD}89BT_ryw2b0M8CO)L3b^{65m5txcdqrHhdceZN%cNu#0+w~NgXG)h1!bpS+w zK&5KT8E!=t2A)DAV11id%G2^#6O@ASX02+{tE(tzZD-P}K5(txwO!u@{r~H=9idmP zvch^%m_`?h%K)Q;02X-R07pcBVX!*P3+vUcxNh~Rz?CMkz7tLPMINsGr*8Fm1I_yV zQ@56*m78f2t%`Pn0g2WnsdAm9*$Ht;nK&N`_6CR zkI_e=F>CT~6OYpO`WWEf14fFSaa8K+H2@EwaO2wtNEAT6WSu#aF@&c^wE)$!JxSaK zOxP9x^bpcbKo0>y8&*ROM_6o~yhp=jjC59Viv#m&Wtw(Tz1W#Lfs6Ob%3dhhw>T_8 z@kK`=Wx1Zs!il*A?w%!T*K*Yi16$^iV$ldZ6J<>10N`>74|Od9=ARp|=Or^+1bb>i z1=ZYa%bX1GQtXjAlk=Ml80^~@;SOoLCz3+0+O^Y$6ZZBe_ zx0eK7FJ&y{xlY9W7HL|B$50Kv5%MXLt1#RQIb?A*_}4cADnZ@A}D2U zb&KbT1NyIU5XdIYF)Mk>ZCuIcR)Uh(1oTnzB^+GIYYu5B7qz)tu+Cm3FZD7j`GReI zm3;k9qmu9VUsCcpbfn~0h}antQ1a61C0%m0m{aXvGaw&AOQNPQ(RCTSgx4hVg}krK z@$g9dLG4e|Sj6w*1Wr=0sr|Oe+4%#chei73g6~LsDe;7*S(fkS(2|W?nvJ$`VOspX zon3%?%Q~iDO=T{Lx{-2vCzxi)0;sy`ij_duDsH;+%$C5mIkn`SU=c9iq2C$?Z3B&p zj9M@q*Y776vWIjkcFXSG9yOZ`Tb7a8abTmmyERSrX3Fe-vVp#ZA|sN_TII3`8nKx z$VhvM8xYG}HsH3#qwJ;9=-}V)@Qn{f3QgNfC(FK}OJw`#O6W49U9lQ|Ic!XjxgH@% zE<@h9p@na+Byn4j{V)&h{2nSSkV`7)xjZfjQs0wPzzz|dfwjB>Rc&04j1A994U|h{ zmC+}f(&JAjqBbY#axEi`gxtw>h5k`(qociIqzmT6XDz&6x9TC2kz@jb$QYj~7>hKF zB?iXl6&k9FbG=+*1~)iA-K?2VrY@Y5 z)HR(`AJ%_0#bDtfVke`xFG>SS-8JZV(XfD;&{OSoVs15wdqc}@P;l>xEpL-7?<*`H zYAp?#{(X~s{uEok)%MR&S-;m>Yul{H0Sl1CK2~WzQD;A$>@Y>*Fjwi2sB=hH@)zs) zDP(@KgrA{wT&HtfO?F%%alD~)ZP2;akzK1Lt`C*{?(6#9CHK26>Gy=})+TX#p>%tu zb9+m6|5M`LqjZ0-b00_X4h!>+Re4XW_Lh43#C!VJCi^5%e3OfOX1Dn)4)aY>`L3w; zO{es~yI4640-6GPnb5+7SAb-uteq&OfAwO?5Ry{P%PO48`G$f~k`lPTig)b?bI$me zgpo*VVapC?Bjkrcae@EL*B7&Wz?X0Gyi1JnYAUq*PS`nsJ<7DoN|wep2EZ~7<#==F z$$1b@v4d$j44=CNC#J591C7Cw#3^fyDd3_kSdEEjS#v~_#6Pt!r6W&=156``2wnV@ zH>|znV6Ul{)4mSIS4g@a*5DI`AF>2>OohJsU9Y&kri?8Gw_Y%5TXHwLZ0 zWuBomh&Oq9uUB;G0$cJJslbfu9|MKo7|F2w7R5-_PN27j>I+%i(^!KXmZ(bbsEdJ7 z)un(O2*qk8mW#pnr-yTl&QS-|SS}|qtNp`+=rxurG{}1$L=P?je@Qg8UoPfhLz#nC zvQ6R%HB8k=-U1nGDmd>@|Kk2H{%kZS&021plYXQlAD)jR%t;SWUK(bgIL zV=l1wbs*w2pWa2LyNKCdMzfxT@Cc}4Df)9?DOHiL3Z&KKtFBV}Vkx##x}sM)cBM4E zSjzKRa$`R*2J9aZ&OAJh9T`Xc)pYPDocN*p6I>QD>z(*8CUj1;Vsa+OYC)BUUJHbm zK$1#j^$U@bc!NUnK+jRZPE2OykXVkc_44p87zEUvaCad$lJ;l>CEYa7K%rSCV?caP z=Bi`xSgQ+sgeewzN_M2f_V_A>U+9yBtRVCBsSnb=wcd>X@PsXVv?3qE(a zAQe3S5NR_;Q<&?76T_I|DuWmd)nKG{>nx(krW?>R7J-{5gd9w2RZ+_fmSUOGHB+|7 zP{_h&%K;y86Wl>Ts-ckljC@SuDWVokF9G*-#1xsjU#3!Kkc0eN&LX)NQ{cjWm%Ivd;_5C|>+;z*}qgPKVcB_wCx_m>Qp7Us|$ z84EGnKB!TUV_@w;o)eR9Uw;8rBQA;2M(}H|f7bULXv8Wq~-UozCWaC3KH3vS=v$HSGB&Aq9 z21++8cqL@3dJKN8L<;Ke&G1~&#es0AXJ{)S7*L{f?YIVRHP=)T^ZjzTeWr6Gc_X}O zc(i!o8X1Fu%&C9=4rh6ynp0r9f=doHdy>SAf}Y@dqup!rU^ntE3H?>f+J5m;>qNQk8bT&0=PnnyC#Z+)1hd5=ndFu`%)x6c7jY}Vl!}mSigKs^*yU5(` zmG3~nJ$0D1u(EdKMm>5c{FeP+fPwLZQ`i@8>vB8!XhloOyOp5tGwd_ zi&qzNt9)mEqGdI7I(o283MHpUPU+ff_a6*0NHcCsQeA?IpBP(yXdg4U@VU(869M9VFx zn~?wO&j=jz7)|rWza$J|9k69blNjv^gdgY=u}762QNTY8cB7aGMI|#Zo(~~>FW2!P 
zrG8L^W#)7YCpafx0yKuH4vnrGFASHX97sooKC)?$Zbg6Sy`ukvomoG?xX0wj;u!Vc z|3CNSD3qdbo$)^U-q&yFz9LZSvXM?IYteoHucAEj>*d>?U$o{{i@A0GgdP3^k*KwL zL9#}sr~=ucJtSm=jx!peu|^{lDEYt$ooh5dQ;eo4U?0K!Tx~Q(GmPeEozWD%VKhJQ z8corMWMqUsF&d#Sj7I2NqY>IeMn>p3qY)ZQ@tG3lGhOAASnV?x>`SQ9*#IxIPSSc-bs;+kRU)Zr__hp$!-Usp3c zgF3o8d~}_9bVJSP8`Lqk!^hlJkGWqn<{>q-Ej;vzI`mmh=nLxDKf}kqRgZmNGq#5s z5f%|KE;C|cZA2_Ba!N$x^vuY_+Q_-Iamf+mQZmOat{s<78^0oA{OZi{>uSem(4wm& zqU$oF8)~C(&|+>!#N5q{xnCRekQUn(5&I-F_E~N03)8G#@BR6F?(g~c zeSGfU@Avh1^q-OAdtKM{d|t1&558zk(Tz2dMZVQHe5H&0ns4}#i~Kuo_{$ar^xO!L zFA5yG5h!xrnAWrra}LdJ3dNl_6*X;=IEPg?g-M+^H#co2ox?kt!e!18JxvjE=g6U^ zNYUbLX*aiFi{rCz#^Z~(7v0=0S)5RPGeNp|NAt}cZPdVM8}&*dLIL6+s)QfE-RhiJPyJfd#LKAKY4v2P^2Z_S)5J~Y7tgno znb7AvlT!ufM4PtRl(jK}O_kQ$so5L-_e(n0O5LLWcxcEh{=lR2?@#6${0wdqE;jgi zN&_{p+-ypfRxK)r8X#}F7KNpLf3;iK)OKh-m6wX6M;3Av%=mMfgga-X2$69Uas2E) zO8ZWQw508vb3;Z2OKQYYEXZ@*wD`kU{R~n+-^o0hUl?lm@1<}m(j(`G6#kC`;=lFy zpRY$v2mfKGD)`LofO*`lLELQwcgIVvl&5_?NV}QPzWY)eI_b1~D}Q&R)3o(E9YI0I zyVH^MV`uMNyVBcsX`eA*ZVH%9PBEhamfk6rOMp526bou415dFMQR^8g_GF5U1hA`4 zv16z5Wq@N(iX$!64M}x%N_96)-9~#dpQ8}BBsGr0pV{#CqTL=FLC1njB37ZuLZ!-qq;m!kC4 zw+2wzJyi8W^&wONL=`<+)OrayUIu9vQI*iRldCJIwA^cyrK{6HiMsQtqQ?-Y>Yi~+ z*C34&b;pYy6OyWX#>*h%YN`?v&)Wyg+X+~|R-g{YS?|ZsN82la?M1-u5MXmE#rk-P z%{jpSVT#?|6#MHb)Pd;m8gP8J(D7S}_fVrZw#X;@hR;8#geI*y1+phETwXn?Z(3@U zbkacsv-0Li7Al)~o{Br@<6nEKlRD$tylK=lsVSZcl`&uO)adqQYL0YWpk&iBLJg6q z{Q+vAMC}s%LL1*7pk_+cF2QGndU7kj$H$-IKT{FWDC-+O{uEz9jE|C5^ElKf$-9BG z<|RkS;|h3cTj#5V1aUW_+-sZA*|bV(k%{`O=BqCj8!T^}w`$T-_g$gBi+sG)MX36| z`8VSOHE~K_ruKPIjpNkz4yt{Z+T@``W>oiegDyEf1lN(F20JN=i|T!*zSUyGL)Q#3 zvC*k(Mz{j?{WQL!FzdB~$#wB8s>FP+FjKDJJ<}9KRDTXq70*ZZ{j!9e>Z56ROc%a( z{@h8EvJuPu`YC1xDW){Q%o~_%3e1_8GKbpy5dc;jQ>^e5Yb?c9l42tT?3+{UWho9+ z&NdBn8!85)4E}fjO#Yd~{jXCc|MABk?UvYov?tRh9j%n}XXMJ~9rvEgR(AE=>yod0F?8>R$fYN(r3Z6) zo89sjcloWTN}4;?KZWG=luExmG=k3%gVMXQF=?hj$BhO_Sv zCVv2ENiQQmK{C>?D~l`O0reglcOnXVc+tr~=@W^t=#xvN%9 zt5$=n_N`X!R#%OL>MZC<9kbxxZ#;ciphW>fBNThnIS;BI%T z&92qmzO&8#rMtu1HitfU$Kf_dboG+-hf6Y7FU@(llvus2_~Ekh)yr!hE^k=9;?~0z zt*cjdK3w^7waeRwE`6(44L@9kdU&Tl^3L?|$$8{Mc&sUYw5Hs{x8{*=gNNU(M}Dmy z{+*BfUwQ<*eH75=5jgxP5cS-c-o7!@Gc>0?l(arz4&4-c#JMr`&sQO~>8_?|rvA_O*KN@9fzB()+;M zjstz(8N(eJs84o!XLhDfPEKbI;d89G^H{l0ZcS%ygU|6>oyS{!@;W>7UizGP+j*kT zCx5szA6-+F{-h{#O>xeXVq(qN;wNXRIN6#fB@JuN-FkAabxmpKlhT)K&cA(fzHd$0 z@RKssw>tf4b*67k&eIyg_j2*m%jLecHBW0Be6QSkdZpF3uJdW#OW&(+pI+_rtsj0` zkNP#IKWonPyOs0o7U6fh_}T4pzdJS0?lkz_z4h#FtKYrOXZK$EwY+`S(&u-7_}P8b zza#y5N2Y&g&ht*f|4H%lC*}T6Yo0%C@PBse`LkC4=bg`=zx3~V`@E~q|Hbh07id6F zdRI?oz}uXzw?x2i#a+LZ2fVB4de;!}{#Muf)_@P4T_0Ws^uF!t?F;xg-1QL+97=yN zlo>dj^J16?{963tYkA;E&5MzSz;Cx+d}|H--udGD%fQjMFGl+U$A({wp#ruToeb%# zV*e3+gZd4nJRwsk_yHOn| zPF)s+?M8K>1a)~3xf|7klGR0PsiiLis6ZWC%in_p^hO;I z#txu1&^vWWFnIvAg+8iFgV`CV9n`N*2JTzvo)xtjVyEpWT8Q><7c5OAuA2pIxGuy zfowJ8>&Ps0738QPTF*X$dPBh)*n0jE)Cbz2fv?AoplhHo4as`)2Tei5h~EYLioqfP^eG?55bP1o1hX6 zNeFoi4TH)wq#^8FbTd?;L5A>i(Qv3rLl%PNq7hK7hCGDKMI)hl4bcYn33MCuN(0-# zKY_+WZ#3`?*a>ty^iD&vfjogGKp!=v8`$~i4ya#)+`!LA6QMy3*#<11%3&YTkZ&OK z(PU^$L$r~75>12jHL;ESlW00LLlfVKokaIQ#+s6i~!LC-_An(|F#30el#Yl_0y z=h15Dl_nO(KabWxZ#3~R>^yoIdZ#G~BhRC?&__*a7`qI;0`+T>Vf-?*4jR;yg<)mr zRcJ(09!8d-_0X86Xfyi)+6?JyVVn6E&|A<9EqpU}0lf_wYe_bf7tlMBfGgg7NK(<=)&1407A9BK2bF0_BiOa*d#FN-jNsRzAD}8NSp-&#_CmE<@(8jP z{Rq`-i6Yro(IM!S78c3BiVj0>wD3smD*6?ArzMFbuc9N+M=fb2yB_@p^=pxl{Ce~| zG^izu#Ol#eXhcgMN!Fud(3qA;$i9Y7hV`{EA^#dW1)iae3$be`12)!{2+3=xB0O7L zN)by`3AWHCh5Sa830rB)gjgf04BKkUg=8bjf*rL*QS2sE9S+vUqWDdy2E0KVkHVTz zO*l+j5=AzlT5zPcG>Ux_)rOr=sft9HWtlqL9O98+ITe9g4)3Ev?bAG3u+60)RsoG@1u5bzcv}o zzmM9(gW9ra>^|xMk7&!I$@{1yJf0bE`yDABwNV` 
z=yG_rj&v)#4P60S=#X3aZRkqaN=LR8YeQXNTOIjUvJG7YJL-sH*zKq{9IS)I@Y_)z zc!LfegSDe;;4mFY4B3wQ!jU@C81`e-502I$WB89ze>hG@7K1%T1KP6!LM|%IDQuz55LjD;Tj&_sApM;3>@K$GAR9eEu20!@a;bVS?Ozo2QbzAmdLo~uh0zGQCAet?nblW zU|lSp-;L(L8+7q_tQ$QBhv`b<$!;_kj?|ULv)`b{;b>hlp8p2TgX45%@z@*m1e~BN zk0;-t`EatXXgm8iv)GF|C*_B-?(T%k*D z=f6Ws;VNC(cI+K`9QA!u>^iES_8k)#S^ez^fLTTSCT;X zqP6fxU1B)9r{b&npt0&(<_M`V< zM?Fy@djRc#gY~dP{s7tuZ_vXNu>tf69Hu8pBnQx^aHO6zkv)h$gQNAxME)T99FEhI zC1QhU7o4CcPb3G?7jUwkD2e?Q?STvQuq6Ih^etSdhbLiQ(cj<_JxLPz6@3Sn=}D8= zBj|g$LXS-1kDwpmDm_^eHiGuTwR-X-as>Sd*XxOr*`w$X{7Mf?=8vMo@Ebim85>2v z!teAX$>bl8}Kl}7lK0DPFNp+*{8=@`= z;-`8DQay#K6Z4-`A3Sx9B-K}%>PM#f%Tfd6se!cJLiX+`WcL=+-O&?MsTjfTSmEwC z(e7|4-3+d2-A;>(zCJj z96bG)Bt2J}ew<9tlck@Kr{~l5O4xhLk-Zm8_f|OVy~y8NDcE~SxVK8Qw;J19gYUg8 z*;^~!dxhLvC);~fzPFyXpJeZEMfN{1-QVW4{~>?>Bf>{TN7^6+iUxO@L|A(IBB!R9`I+K$sWjhMl@D|EskmtfxZ`0eXoFi*HiuOg8maRB>}Hf1KxvypHlrG4kygT_DoVJ2DeY*aF>3>^0 zoyy8FJURW*9=^$g>9e-frGffp)AcPD=+9lNZza~ZKB;eeP2XPpeRB5B^z408a)y@t zj}I=?j82*xopLiev&E?BpwZb2M&}+F9V@xQICp)|xx0+ghkHt&G0wl;)*RYfb6oNAslAuaDb`-xTYE+E%Jsch?kd(j+*`MKP&|A1K*m6UPWoUtASfge5E6Yg0DoWofdZATJuvJ`&ReXU}LZemUE30I{hS0Yu zSZH%P*rqVWrntbSq|v7Il}#C7Tdr?gvCy_M*tROgwx+WPacEMdRea?cPRxkMb z)&&NK78qSzVDfOmtj`NfRTrAiSvY6)Ld&fS=N(#Tb8(^F!-Wo?7YbCJ*UoWXx7s;m ztMkS~&YNB*hE0Nkq|V_#LBTTT$VrQ%epffl84P%2QmE>fb1=D$nh?xHDEmpD?l zWm1%snzB-qlwzg-ZBR-ndE3V+K6PoI_1vJzqH&Hh{uYE#|0oEcGHKZa{~4e6|3^mS zL^*MP1|3fpPc~F#Ppq&z5r@VW<~dTY`Gdv0zWD6H4;CYpPq3IdKUmBnY8Ui(VDb-O zL*Wi3%4WkTCB$`tw0JOm-6$FM&y;~?^*OXryL`6(t1B? zlFXgcK?Mw~oml!k=*Kw%@~ZKMkSPbzBM_6?5U5tLp7aw%I%pF~HRG(Kn+oHg^s@ed|EMn8Pe}r=|7T1#xFzSPXt+rJakpi1x0iFNc&Vg5E zDBT_^r-=$_qO#%obmdbgQm$O}2BP$aGRIH;Z>@TxeR{~W>C?@p|Kv9P?^~uD9Gq@+ zVfy&7|7p7EG<|dPiA1?A`tuIz+g#AMd!X;|NuSDATWkLJb#8zEJwxp;IcQtVXKZzw z5xZr^wu3XaUzoAu!HlF&Gw^BuL`mllkn=ZnFor%pN*8`|E|--yX~!{WKd< zF`sT>{*$};-=ob9vdoPt%uU+NXZ4$#s#us?Sj=&^u#C2tmt|p7VPV&1;m~g(P_bNV zVY$xTG9=n^W0vKn3d_xHmJ$7yLKUkm7FJu`tzx6Cwq;puudv$DW|h=$g{#<{w6Hnl zZgVEurYOtiY=zCaHklRx4ym0Pc z4}8@biv0hI$nuW}E$KgUwW!+?sKfc+JJtUm5LvFx9#xpLVBx>fS^lHQvNhOhTY=S% zS5{JeyX&tkh5AmVx-fiADQTk5EZ9lZOaqOGo@NC?Mahx#We%Np0 zPTRQLwsWEYUNmmJjjL}zDE7}Q#$~wcGGohp3q;$J(wTm@uD>pKE}a`(s#mi z`Qf-6ny8R>PE^PfTFitB({y*d^rdLs?(NkSu!AWHW?Y5&A;C)OVcPz2kk`EL z_3}Sh-qJq-*eW&A})ybXX zjBoF%nR2G2NKyK`4JM`@fW+)?Yljq1()Lr09@V=wdLNpkvHxbysmas#hnQqPKorj! 
zDYi*w-8H#TL0b0qQ=FX=eyF}6SM@u^ypmd*kX(?^qikBpS;}F1~3|3YH z_8ovjPs$}J;64vry(E=#9>#lHp`y)?mbM7+@CH3Mg4Jn?);}8CZF{Y;i9VNV9!XF* z8IDqiYdXMy9x!kSe1yO;-f-;3Ke}CSX52P-TN)fs^_b(b1e{R4dxwy@BU|zEHf9Ia z^QJMAdf-m^?#?0j2`!!KJ+W%$v*!Jo>?u_5c~G{W>VgmR(+@eNXKkEvByCFewkb#d z+6Bwf^T>%VxQ9aLGymGA{V2nzX6>gXzbjb=F2P*I=s%v?eNKrlxm12AG1Z4Ncj?LU zP>tCAl$G_D_U+=X$BDl`!5i-V(Y}f1tN*owOPl`>Pir52`zcuC{|Qy=|Hs^!GCh^G zCJ?jcH)c=v{BZ_dM1OmH{_hz26#CnthTm}dyNeSH`Xc(_*#RwC%eX5{j!PXn9dzAi z^L5{&x}WFwQqIf}+WTJl>F!xDmHmPd&*&KN*6p@}_E@&eSZW$n{jM%1o`gsoFO( zweL;Ud6cR1e5&rRnYtgQ>V3)7`##km^`OE2X@-Xn8s<$iI(^Wn)Xi8zF>A4L-87RM z2TkrxoB8PA%<^mh>)5(M?0Dmv@i&zvsfU*AS6zDe(9%5BWv34WX^` zGq10n*(_dhPj%&^Lo1)Fy8L>`<%8;~FNaorKQvSH2f}`NZD#iWrP%+MV*mFk_N;1w zTufJj%+#9&RJycj3Y`hfRqqf8Q|QW&wR(?0l0s)e_Uc0dc?w+}3Q|uCGTlYjfYz&L z2MKr4HK9=Tq9DmGx)v0!UL7RgMc0O+)SH7$chhyC81;@I;cmJv6tCVBB-u^ZgA&z; zg5JitTlSpNwVmxAbX9W zb@D8_Hx#6iw%+sz-3MB)k-c7cguVs})hJpoIYRe^!ZoVb%a73gpeyWuK3L9$W=E*7kbVM+*X#+E6rwhNq1!^%(0?A_ekUxWUxkJ>d%`4T^BsDCdg#06P?)@o z-VE`y(l(o3px=W2rj@-}c!7Q!GSn*CEV)3x1I^T`-YmaBzYCdZHE%Yppx=Y$YISTD zR?u4@YptHmk_!5L$X;t`v%G@d0R?HLg_~ZYcS7s6vcrX!=ue~l_a*eJC&(y9K%CFItU^8mB(zKDzgy(8^2!)MwW!PG~M<{8evtWDeA)&mHt_}xj zr$w1I(KX=p+SyUUCb}jZs$CQ%X`*Yv;o8+va%z(qj?!+9GQCOHfn&5gqJ%f;x^TR9 zPn6^)T@Oyw9*UCRq#MA5cG?!x+jK+tw08Cu;cdFjUwAjQtMD(1ybl2@U^AWOt)^}C zmGB3d@?4#ct-?0C3v8{^vsKbYUj^Ii3~iOS(Y@gyowOL!cDfI|UMD+7*iK&qhw2o? zNZRSXaJWu&jJ%!h2S@2N$Cy5*`@=Ch9Wlbk^Z+vix>I}umAJaF&giczl z=@WV=d|D?vR``Uz2`<(tij_Q}hry*f)v@v?^v!U&PIIj3Q+haDsnZcFd`gdiYjk>I zB~R&*aQ9!pcLRm;{?pj^XW#&At=qFr@`|1TQ&I2sxW9sc!Xl`ET@daJg>tcGGwCQn*sL zW4rJj{XATw+p}Hrj$Q`W=?-m|zoS>fFLl!rOnd1y@UOZ0NNV zUQeQAkp2Qr)Ei2a57K*JLN6`J^eg=>d|EF%N%)oi8(geclqC5|e+QT9RVT^6(%-}7 zdd*3uBlHh&rCvvpaD?6q*XZ>mNk-@&;X1vcB>4z^2!5%TmTWpoABKO`!*>1!g;Q5w z{A<$i1-7AFklb{ zX(vthC>SzM8)Tmp?olve6dM$slFJl-RN=}gn#mgD94W*~pS>hE83q$f0KTEulLG5u_8Oly!Sz;H4t)cuB znI&Gua5SW_ujn-Uh}fGEY>1ua9})X7HW=cku_NL&j4(sVY4V8Jml0_wJ;myVhJ_JX z#4ixY zS#dF=&Gu|0VO33r#TE<5sX$iYbe1*|(M3(T& z#C428BUuSnCcerTF_M>%W#W3qn33ol`+~Tcp>K?x<6jWpV$3kc&tVtDw;9I9l5^w* z@g2r&W9d0|h4?PR!k9eAuMpp3SQ*RCVHM&QhOM#u99bd0&u}yrm9j61I~c*nSSkOK zxRbHL7%#;xiJvgSj3uSyCGk^6q_MP=T_t|Th&Com`BmcQ42!5Lg)TTR#%bg1^TNvtzcGr9i_S|fE4*Wr8dsl}Usia}C^v3C zZ(6JHfl+DPab8%f(95VX?l~{1RrtuLGafoGuT>agyfjWLGrg)X%=pzfyG(dh;Va{< zaZ#D%s=^54gK>44{Hnq?MxSwWnQ6Vkcg7dvjxu4r!YE_dxTj1~uQ0~=Zah>bP#H58 zN!ZuKlNI$%sP!dG!oMb-qBz3@mtfb#3`JuTiG;i+R#cpABBdg2#Y&16CZvSlC}t{J znY^B)tY~f0Bat*JuoNk#W^XbikvA%+D+ZaQm76vxXeh2X$u1W*DQGH&niQ2wniRAY z!%eEo~p5T1PR?L{^U76zeJ`n8?e?n_@l1WRq`y(KAFB z*tf+7iUlUv1^#WZp<XU+oUOsbBBG@R+IuYD>FJyL5J~` zD3j`ICd~^>DD==ltw&v(DZe(WN5O_#v;QWhR?o&)?)RnGQET>h0o!vaHq^@f#qmWm z$A{F8-9iUhqdf&3erP%P^`+TAFU77e`L|jQ_1CfT?#>m@cDlTtFy`SWL z2=v*PO369nT29fT-?bd+4d3QPe&h|mhhV_7RR5ku0q++DQu>vsDUi}~u%?aK&Ofvq zmznsKA1nGlxvHIs(lt2v_wn1df2_Iw$++hMe=N%W7@Gsu*?>(EU|S8?HUF5I|2{%@ zRR`S+z=`2G_+wP=wGH%61AYG71Z)QVe~iZscK;!R!O`>JtxJB4%HiK9+>2;uOrlCQr#!pK{P(O4ht7hn7q^>^1kE__j$ z)bRU_c&H|G!H?|EzwX^$@?lS0Ai<+D092^=SNtW8_Wug-_&*`;<^7*@;+~QFpUd`l z$tUEv|2ZcvlRY(aDlc=IS*FVTOjXxRPEe*=R3;bARL{)RATl+}GqoBrwOccFUS{g{ zW$MwhOs8g<@v_X#vMlCj&2i0|8{_$zUuHSfWE_l^rOP5 zN27R0x0oG`o_}(Q8?qp?v(dSjU&zDcl@8jivGRyCupa0o4 z|4UH*Kve!9nm?48KTPC*Ezcim$p6-w|NUkDXkY#qonTKRrcNiOnG-4t2vs+NvzAcX zLU6@|`aweDB%yhM(7HxwKOl5|A#^_xdI|-m(+bR{7nqwDSS%=*<5n4EBw-m;S3%4CC zj6Yep{X${FwZa__3KM@RO!`!qtWbJF z=+ zZY$S$Rj%7#t_M_@s#KWiSD0H=SS+lV<6bd0xWY2J!YZX=URH&5L4{34g>7SnU0a3y zs|tty3P+%luTtr%U+HF1>AtXXwR@#UaHVH-rB_O&cUGlOLFJl?O5esxzqU&MSCs+% zm4QH&P^BtLziNv`RrJEDt?pGZ!Bw%*RdFd*+p?B!Ib);toq@C`mYuBBaQXn+Umc*svqsI9|IbYNx)PkXH%pBp9H8V 
zxtKaNARxd|@-!7R;2^+N3NRHlP*;X(D215f4LAg7DMgq{8xR=KQHnK{HBbrNdP+&A zv_?b;FjHzWMH+D>z(T3T)Tt3+0&|txO$CiO6R=Y1G8Hu<%7C>}w<+F;D+9Jly{6Je zgaz0u4VcOraTegHG-^t_fv5tmO3r4;4O|s)S8_3Px`A*24<%1C!3~@Pcqs*#iEbcj zfR9p$8GZv-1ALVt%%nFEF5s^eYbLvabAdpmBs1DgL>q`wYBED^;@Uv8Qj3|>O+*KX zQEE36+{AT&IHfK#(M?1bh*#=1!*AlcK!Q@Qne-;22P7&Dn8|M9dO)($s2S}R@)MA$ zbHy7N&X9879UFM=Y$Sj~n zsoNaCgU<&H~s8<>_r`<#50If>S7RWt(4$!9LV&QZTnF~Bp z^0W}#!{-8zl>#h8_Yg~~ZU^)$wOBaaN9=(wO6?Yc`?x(YsMKX4x{o*j!%E#2_v8o#!B}_|Jd}(`AlRE3$}69QT|fXvG%+TxP%=Q7hsMXfQ+O z;H|hbpv8=sBW*<%13JvuIkHxKF`&mxnnP1iqG#Mc8y znE{rfP9y}#VTM@Zop=b4%Z#v;b|M>qJZ7w=tP|e=OKsmF?5_yVm1}c~> zmQGKRaG;XeZYg+*hXYm2E=$oiEnm`PT&7sw8vkJ)5}yufz={md3Crx!>f@P*lKC3t}+0)xyhE71!i z2^eN}Tj4M8Bw&QuYbAYwBm>`>16HyZcrq}?9JQjoM0NsGm7V7yFY%p#in7Z*r}11ywV<~hAW_5*X3+vf>h;rjtA<*s?6SI7atTDf~3{t7<;*edtVlfFVS z0DI+ud9qh{2H>bXI*;}%atLr$cD6=-#Sa1Q$}ZMUzaobL4`okl!LRsXz)Ly6TJ$S& z1n^M~vBrPJj{v^P5!TXQk)wdWa;&xNSNtdtsGMX?dxIPYqLiDgkvI5pAX>S_+UX6F z2gE41TMORcc|e?Um$m2(asr4~?zYC?;3t3tIlxBr7C8gtD2LeKZ}Brgu5yHp^es{d zX+3V5y@Yb*PJR{<}SlWb`pkt;x- za+59c5x)ZTE4SD>eMIVjFUsw)B)Q8*vI4n;)K_7kt;IabjM14pTpuq~U z!~5_iK#LV&C+$OS0y?Z%J6Ruo6VPKN+0puuyMP(1$qwnq?*bOA7CWbY_vk}7ZA_tw#NtYE+B!` zYcCx{UI2-#0ejgX{sKs5joQ@L#$N+jEEfl-VdPifFw4_HFpU2S9AyPK zh=!4FAcqy=fDhx{KrSo7K{|}Q0rFU}4zgkV4Uo@Da-fYM?}2hwlLIn>zXvKRF=>wC~7g zpq1t9hA&SrLxX@5mtV zoE7UR`;HF+FIY*Av@zrx(8p?WM8@!MKtHR+(P<3%4t!y?I||0|@4z6d%TY9ji~_@~ zZby6!9|cBOy^hi`WDNMu8gP`2;bXuUYt)f8pEn7ds_M)~<`a`Z6;&6$(|jHXa#TI} zg82jpa#aKPqWL@s)KCrKm3*TuWj|t9IZRZOX5=_uawTmxW$WsQbRlE84LP8m|Rqf?V7xGx3z3Kp8wvb?f zj;f=4+9IAR=&I`MiYy{jL3dRbSEofh4(Ork=_*)6a6m8B09Vl>o*L+*8sdsCBGf=% z)d*MVA|4m?SB-U*Eh4yJplXsUZ81+9j8biKMHUms3eHXe)Vhz*bdfcVs0o2W(Syad%qDn+rZt^>i1kB<6yT zRRi2bD|wb+r)r2hzLKy6pQ=Wi+Uo#W!+ zEy#(4<6=tdJ24q_25xXfTzfZ7Xs#R zLOgLFA_UCkM0iSlcpJbxPOPWQhu8q-bCNu1zP!y~Ij6}J@g+8c6`U4NCtqGTSjlPk z6!;S1U=^p!Q{>Bw0Bbnip13a&0oHPQJ*B?9NU)AG;3@MZBEfpjs3*;zw-s#VIC~-f z#8$A4C+r2f1(@Hr>eOXg3+ zfiE~oUbH~o4zQ2YU47Dyz6W1LYhS`cq1I91Kr8wnzIf+}h*-cCWhRFI?Q=`9E%QbDd- zfVU`!w+qxz3-QK-h+UwTT7Knt}N zZ>M11esHc@ySE^i*biE%b$N?|c?UpiwQg@bm^c91s`YwHgLxUCz1o1cESSgu9o0s? 
zY3q51Kvy+qA7njo2y|C-@o`$uI}CcLdHM*}6Nf=BwE!Q{dfpMxM=it$Ur!tXebpj- zr0aP{L4UPaAK7~1C>W@g#ELUq ze?biK2=w9x_={qA?Vt}g#2=3#+Cg7#gugU~_Zal&#`?=*h{s?cH_4wC$9o1wahv>+ zIN})?&28~_isL;8W4P`9f;i$i7{~4M7sc_qz<6%AKORSPfeGARe`y@=1(?Vk@R!9A zFTiB(s6Q>9_ZrOPItL)}#A`5%>k{A;&-)cT%=HWq#1p@QN4Wt3qIh06n8OVTz~hN- zFqaz?cn?-^TLPRCcptz@ZhL?rf%pJc zak~OU3A|pghT9#0ClI}0Ew?v7n!x)A)^P^{WC_Gau%0^_KuhF(23xt#fk-0p8EoUa z1Ue=1zJQOoo`Hfy;tTkg8xSZ;k&#kr)7!`%+dA~3K|Fy=4uBBMGZU% z*3b?S;td1@YiUOar42k7*3pg?${Gk5*3(WB(i(Y6lq1z7L>dVtDi)+g=+ww#QnFOL zP|!#)seF(wp{S9kOj%RiLcEbsrh-Cxh0;bIi_)hCgtA70MJ0xe3TZcZs<5lJa};ue zP=(#KU807v;(3j?9%@yBT#8z)+(JLAAsC}pKv405ZEdy}injJ3RoYsWpomop z7_OmKrJxjTYw-d_i}1}%vNz=XIzP_y`1r6hd(X^Tv(~%bwPt5B0ehFTIaPlHcg5Cl zl&Kvz2v;nQqe`XS#5v$ytD02)O@afC7pYH`+{C%yrK{Fd#Z7_>&KbF%s=kR&0dHS* zr|NDJQ^0{EPgC_baUOULt3S2lCcy(Ij|`>KZsF6g{hXjQ{w-n}c90X1Cb@-A#}08; zrYUX_)3GC*wQ1^G_zbL=lboizMa;lH;cQOR-@<2NpK+9F9k+;?*ykKo8m$%=VCOkC zY5ZD30FEiCPm|Q*Uf4xWYnq~#@B-(R+)q>2;54kS7i;6JO;^|9bFp?#a=NaL zn2X)zY);qL;q$P29A$b(9Wf7kz)_{s8t_HfGfqu9zkyf;jytJOmo(u1SRbb~UC}`J zgL6;rr>h(A0PGc~J6+d61b_ojo~G*?a546V)1ThaK#0LfC`0Kq11`m!xIr0w10ltn zxe*x>1HKd+&s~|JFc3>ISMJ&jwE+*pCUKK9bOs^_bK`E#&>QezjLTJKbQp+WjK@`F z(3N6#6 z_!{soT5G1FjaUQDVY#2FZo}7tm(sd3b#262a3IUmOnn<32i{KW&+KR;;=su)Lz%R8 zJQ3T^4N~&kiA3xmH$o|C$CI!_+?7g2JCTGP;jUGx+wo+qn47HBwG+wMC)~|SeLMah z_8C{H>}V(6!#?M#l(f6}``}f!8YTZO@jf`lrCuqyi*Ew&w6!V~cZp5lJeT`Q^<8{3 zc)6`xsk=*T1_!%5RqF5JA7EFw{mPEJ#0TJnmmwwX9-e_2xItO`dqf7-#Er<3+`}`m z7VgR{#XTYuYvZoXQs2XsSUWd4OLvb@Vt2Wlv-J1yEbJavnbmQR$ig0QRavwL_%`er zwwGymb}FAn>i|2-%N-o2 zChj=J;Y)vfD4Q?YQ%F3si>Y1iql#XU7iX;P3@(TZD{f@zerEX%KZp@e9WT>_RqHM> zetqxu@VpjQm?SzjReeD)zH8o!8w1RX1M)7`xO6P|fpw;#T$+bwkk^ubdystDb{9f~^_yeb*eK0=z(9l^rtH^4R~Q zT{KImSpy_=jZS_*Z!M3>dk{0eYff`#a6ZscT#6UH`+N?R2K2zMlSq}l0@M7imL&$9 z*hfFb`w0psz4jtly)oz}?sxqWTs6>+ zT(j%_LJDQ;h^ks_#CC`dPPHxk2uhpnG`4}~TQQkbhHBa(RnaTWaXbyRR?%+o#l!RwAM?Tc)n--`Njoz;I%Ynj4XQm&yys~MygqUX`*DH3SKUq zwkW%3{n5W{q0Fw#s6qM3YQ~=3^ja|GY(>GG@mnWJ#*93ew91e6S0{YXh96$JQ=~n{ z_O_kO2O@}_#@n&C5wMAQk+bLQ#gVDY?H2TYcoMcaHHK;XS)8}n9&lN9idZKSyWEgM zy(Q0rix;pto319!2NaIPa%V6xQ;_Dc>gDVu=^l8C`p#+-mo%JJ6`0Th%@f&UbM5@Q2mlk>O5J4fw~B zA*(8VqW>cTb~{fJ-Sf5ndpqbZuDOh`{_;Vm(-*D2u;`?CjXmktVqnn>#`0ynsEt$M znf00_HumxEbuWW8Y4)WSQY9Vpyt5D81B5&nj)QxO*zE5%vJudv`kxKZ&k~nGq4zKG zT|ykRI%g?dk>|2OEnJ2<&Fi}C0%-r|k8WNnT^&h$&1edSQVMbK9bfG=Y}+vz)F*;u z*&LvTr&e`d25>{ApdWrXXi#)o#ED`QC*DZDHk%Q<(M9t!M*GB~8>iEyjMMWWh!)6X z(Jm7xWFJz;fTh+vVXK$BK~oqmmpZgC7lf2O4n=b`AVl5qj9py>v&nE}R+Xrxnv7;Q-Sn?fUi4lGr* z@iK)*&E-Z7&-bv9bg_VoV7Y<|?z4-;Zjj_;FInvz6(&u5Y%zYZ%3(;2SNdR@t?UD{DS}hC}lD@du2vVarI{o&?rTn8G z95lgSZi2P;QsFx2QkB-sS%%GsleuU_cA>rY?<4A0D_LFelzm~98^rQ#AAa6vwp6^} zAL4ovp%_q9B6E%!4Kc+*`0br-Iw6E+%=Ecu8YJ(hluR}d# z(`8L2U!*%;6)l~$XbWp1@&v#;gNYN2;(4%xN&%z}^X;O{9mE;;@caSWs*;YGC|sv~2;~JUn}pZ9y>8Qh^L~ zJh0~w6+1yx1B^2Hm__(90E$pjY<=s7HH=t1--zlIexPtVA^^gD@n3eS2Xc|ss6}sb z6I?F&tEVhPKji`Vz_loT_==Qg>RIb(@Rne#-!GcD#rV094tXW)t7A#Lo)W;-6spT9 z0kkMo&lQjYsNzvj_H>9ft54{6z!V#ZE7s?_P=KxU3sM}Q><30OWIri?sZ-uH1IN+x zmLvG|KPSntH8WscC1neY>m9+x;@B+ec0ifr-L4zv>jj7g8t6>TVu#DukfVb27t_Ou zT*hT#1CTC`*7y?W$90xp)SqMsXky94TG7F$tQ30`0swQ|D^_ncg#dQrn0=?7F?RMd zAlxEDxfIf&3PTpW-%)z7#Co0C9&+=){6ihE0*Vs}_z|8n&jX>DbKCXiv$C@F%L<{J zIp5(dSfJbc0{(3Z`O)z=DcNAEP%JinNXZ8CrEb?_3)Sybz$bv&Dv3JAFGmrD=8X*$ zji2nN{{x8v0x#H)g)yx)M~T=;6)jto@@f1{I3-tRZ*YDIJ1=1JBIp-g(mZb$S@h#I z8>!Z!1MHqoR~Tq2)n3^W8Tf)aki*r` zf1e6rNbKGGJh*r|mG@D81oOUK)~&~pvjZ)Mztu-e>(awOVQ&N)ndVema7SxHr8F$2 zhUYV+)6pODQNK)CUvU0>%Vrf(Lf?Z^%|&M=>ktk*Ivv_;tJk2_@#eSil9nxWtUCmi zSKa4Q;Exf;P2#i&OoT#AWO1R$GXa1%iD}Jp=lS{dxp_rYu}Lol*@fK==$P 
z4$<4pydpG8^a`f)-L63+#e^kZkZ{jsKm4~jpTzUiSeNfQh`@Lz&(c<~zuG$)I^l34 zjma<`C;1g)`72ZGpCO?5xd}gku(Rw$pQES+6E(zLZR|)!QveysClAgrvq@6+5fC?|Af~v2|5E#2U0YI_rAZbRjc{~+L z?x69MIEa2jIB6n!F+Gv*p&ZXu!uO06qa+OcIlh8M&a`XMj)I;?w2C|^QZ(kxtPo1b zUgLR)QcrRsjR2IUP%Y;##;ZOCqzpW~F0p{JeU$7$+xJ*<{s!z%T<1;LZE&W0qyt2r1a_q50LuV zjU26iAhpE$N331eBp*ZffKpQvwAVck94Gd4h7ZDdmY{kCFHQ z@^DZWc{0oKbHra!dM9*f46{4Z-~@rjhT$S!>{!i{_gjou)c%cx>_T%z0E#p<-;kxq zEwA=0wI^wp7*#SaZMPJsIGGi5%mI2Cs6bs?VJx@ktU1=Q^PLDgv0#^`;47x^o zlZl%{VAoVQT(+MMdIAMQ?Ns)dT!bCqV)rWhY2S?Gcg^MC$zsOqw;?vi$znj*-ip0{ z5OMe_5H5Xh&)0=shwG{Zu|zs4Sx=gksZIsWcbWZ9w$l5mD?iS~Uws zlF5sLZ0Y*O?KRL^g28mOT=^)&f+)JotD|-#KaU2rrsC(RS!f!gpv5u7ATM>o*1tt+d-@WF0*V(?>mdBiU`Uon?3p^b9>Zq zy3TUkWhLG0?*SUoPGW^(%IVEOBE+donPoBEmL&XQkgg>m zc@8o3Z$m)K28a{XG6{y6vR^FNJ~v z42}JY0v}EO0Cu8*VLXM_>GNjoX$oC=V+(k`LyEnWULhRk$Hgi~muO9t;B3g9jV7l7 z%yEOT@By-HSUdp5CxC@k4$L3Oc>a>n2$J^O11KdHJIZDWz3*Gzp1;68qa5CDT6^Fc z8WsDe_nbN4qwM8h&J0WxRXqce%)HQP?|2Eiw-}ZsU=_*f^k+eok{5f2iD8LLn+FJ-s{{PpKK0uK_VH zTgv|F)(6B#S)_g2ILcJF(4n4VwG{%Y%F6g^NvJ1c;eXX3{;T}xJIJZ=na| zCNxJ9HyJl_sZz=>OX8i$s0HvYdd~@V$5b?95I3wEV|Gm;Y2uK=73?yBT^@=dHS)_#=k7Hs~7B9*9W zg6#ZnT4^7u{J?}NJpTKU!csrjq=AK{3-neTDqL*ZW9(+U=|lq#`%PHhBdEGk76=h_ zPJLAD_R)yyl<;h%9y%9nDIBxkSpfkJ;FKw5Xu?UNDRlg?lV#^d0oYA6h!^sI+S7=97iY23X8JYPoI9%Z)iQ7#MAx1%Yz0hPBILcNO~G`q%1( z|Dbh$j~~Yxmfp{C2=!V zPCI0+ZhUh)eB}b{MeSw(IC}!5FyhvuN-7k=k}b&Q0vyA{i2T9rg6&i;Rw;nHIh!o6 z^R<@O1(}xD#pf`xeh+tle1Ou9g_emoq5eVO&~!gm!e|S(1mc)7EmKP5TAwXNH8p5_ zZi5mMI^o^o*dB}PIHQibb^+T(%^t7xXZ}X^0$&G-h|jYkao!bRU*04t4~Dx5zWp8~ pAi>@Im$4^kpJn7c6z0BW(dL<6VgT=!SiyIu{vZ4u>U{tJ literal 0 HcmV?d00001 diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py new file mode 100644 index 000000000..226480fa2 --- /dev/null +++ b/python/triton/language/libdevice.py @@ -0,0 +1,1661 @@ +import os + +from . 
import core, extern + +LIBDEVICE_PATH = os.path.dirname( + os.path.abspath(__file__)) + "/libdevice.10.bc" + + +@extern.extern +def clz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def popc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), + }, _builder) + + +@extern.extern +def min(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umin", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmin", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmin", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fminf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def max(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umax", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmax", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmax", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaxf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmax", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def mulhi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def mul64hi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), + }, _builder) + + +@extern.extern +def mul24(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def brev(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")), + }, _builder) + + +@extern.extern +def sad(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, 
arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def abs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def floor(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcp64h(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rsqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rsqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ceil(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def trunc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def exp2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def saturatef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), 
core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fast_fdividef(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), + }, 
_builder) + + +@extern.extern +def frcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ddiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), 
+ }, _builder) + + +@extern.extern +def dsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rz(arg0, arg1, 
_builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), + (core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def 
float2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def hiloint2double(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double2loint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2hiint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2ll_rn(arg0, _builder=None): + return 
extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def ll2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), + (core.dtype("uint64"),): 
("__nv_ull2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def int_as_float(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def float_as_int(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float_as_uint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def longlong_as_double(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double_as_longlong(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fast_sinf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_cosf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): 
("__nv_fast_cosf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log2f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_logf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_expf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_tanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_exp10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def pow(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def rhadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def fsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ffs(arg0, _builder=None): + return extern.elementwise("libdevice", 
LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def rint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llrint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), + }, _builder) + + +@extern.extern +def nearbyint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def isnanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def signbitf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def copysign(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def finitef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinff(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), + }, _builder) + + +@extern.extern +def nextafter(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinpi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cospi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + 
{(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan2(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log1p(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def 
acosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def expm1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rhypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcbrt(arg0, _builder=None): + return 
extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def yn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def jn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcx(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): 
("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def lgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ldexp(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def scalbn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fmod(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def remainder(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def powi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def round(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llround(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), 
+ (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fdim(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ilogb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), + }, _builder) + + +@extern.extern +def logb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def signbitd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isfinited(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinfd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isnand(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), + }, _builder) + + +@extern.extern +def dsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), + }, _builder) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py new file mode 100644 index 000000000..6d0a04e8e --- /dev/null +++ b/python/triton/tools/build_extern.py @@ -0,0 +1,340 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod + + +class Symbol: + def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: + ''' + A symbol is a function declaration. 
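+
+        For illustration, the libdevice entry ``__nv_asinf`` (fp32 arc sine)
+        would be captured roughly as::
+
+            Symbol(name="__nv_asinf", op_name="asinf", ret_type="fp32",
+                   arg_names=["arg0"], arg_types=["fp32"])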
+ + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = arg_names + self._arg_types = arg_types + + @property + def name(self): + return self._name + + @property + def op_name(self): + return self._op_name + + @property + def ret_type(self): + return self._ret_type + + @property + def arg_names(self): + return self._arg_names + + @property + def arg_types(self): + return self._arg_types + + +def convert_type(type_str): + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str): + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: + ''' + Abstract class for extern library. + + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = True + self._grouping = grouping + + @property + def name(self): + return self._name + + @property + def path(self): + return self._path + + @property + def symbols(self): + return self._symbols + + @property + def grouping(self): + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file): + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir): + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], + stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. 
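+
+        A minimal usage sketch (paths are illustrative; ``build()`` below wires
+        these steps together)::
+
+            lib = Libdevice("/usr/local/cuda/nvvm/libdevice/libdevice.10.bc")
+            lib.parse_symbols("/tmp/extern_lib.ll")  # .ll emitted by llvm-dis
+            lib.generate_stub_file("/tmp")           # writes /tmp/libdevice.py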
+ + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + + def _extract_symbol(self, line): + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self): + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + # The following cases are grouped together: + # op_name, op_name + for symbol in self._symbols.values(): + op_name = symbol.op_name + if "max" in op_name: + op_name = "max" + elif "min" in op_name: + op_name = "min" + elif "abs" in op_name: + op_name = "abs" + elif "pow" in op_name and "fast" in op_name: + op_name = "pow" + elif "round" in op_name: + if "llround" in op_name: + op_name = "llround" + else: + op_name = "round" + elif "rint" in op_name: + if "llrint" in op_name: + op_name = "llrint" + else: + op_name = "rint" + elif op_name.startswith("ull"): + if "2" not in op_name: + # e.g., ullmax->max + op_name = op_name[3:] + else: + # e.g., ull2double->ll2double + op_name = op_name[1:] + elif op_name.startswith("u"): + if "2" not in op_name: + # e.g., uhadd->hadd + op_name = op_name[1:] + else: + # e.g., uint2double_rn->int2double_rn + op_name = op_name[1:] + elif op_name.startswith("ll"): + if "2" not in op_name: + # e.g., llmax->max + op_name = op_name[2:] + elif op_name.endswith("ll"): + op_name = op_name[:-2] + elif op_name.endswith("f"): + op_name = op_name[:-1] + if op_name in symbol_set: + # Update op_name only if there's an existing symbol + symbol._op_name = op_name + else: + op_name = symbol._op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file): + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self): + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return extern.dispatch("libdevice", , , , _builder) + import_str = "from . 
import core, extern\n" + import_str += "import os\n" + header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@extern.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f"core.dtype(\"{arg_type}\")," + ret_type = f"core.dtype(\"{symbol.ret_type}\")" + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += ", _builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + def __init__(self, path): + ''' + Invoke llvm-dis to disassemble the given file. + + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path): + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], + stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self): + return self._ll_file + + @property + def path(self): + return self._path + + +extern_libs = ["libdevice"] + + +def build(llvm_dis_path, lib_path, lib_name, output_dir): + ''' + Interface function to build the library file. + + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library") + parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/python/tutorials/07-libdevice-function.py b/python/tutorials/07-libdevice-function.py new file mode 100644 index 000000000..bb5f7b26d --- /dev/null +++ b/python/tutorials/07-libdevice-function.py @@ -0,0 +1,74 @@ +""" +Libdevice function +=============== +Triton can invoke a custom function from an external library. +In this example, we will use the `libdevice` library to apply `asin` on a tensor. +Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions. + +In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. 
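+These wrappers can be (re)generated with the `build_extern.py` tool added above, so they do not need to be written by hand.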
+For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. +Using triton, you can simply call `tl.libdevice.asinf`. +triton automatically selects the correct underlying device function to invoke based on input and output types. +""" + +# %% +# asin Kernel +# -------------------------- + +import torch + +import triton +import triton.language as tl + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = tl.libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + +# %% +# Using the default libdevice library path +# -------------------------- +# We can use the default libdevice library path encoded in `triton/language/libdevice.py` + + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device='cuda') +output_triton = torch.zeros(size, device='cuda') +output_torch = torch.asin(x) +assert x.is_cuda and output_triton.is_cuda +n_elements = output_torch.numel() +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +) + +# %% +# Customize the libdevice library path +# -------------------------- +# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. + +output_triton = torch.empty_like(x) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, + extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +) From 0a3f3d5f250394a2b1399c247a96c7839c6d02a1 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 13 Jul 2022 23:45:27 -0700 Subject: [PATCH 153/215] [PACKAGING] Include triton/language/libdevice.10.bc in package data (#582) --- python/setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index 6c136b6c7..b0e48f251 100644 --- a/python/setup.py +++ b/python/setup.py @@ -141,7 +141,11 @@ setup( "filelock", "torch", ], - package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, + package_data={ + "triton/ops": ["*.c"], + "triton/ops/blocksparse": ["*.c"], + "triton/language": ["*.bc"], + }, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild}, From 5b04331dd2efdd23f4475823761fa975de60a514 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 13 Jul 2022 23:48:58 -0700 Subject: [PATCH 154/215] [TUTORIALS] Added more credits in fused attention tutorial --- python/tutorials/06-fused-attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 89aadb1b4..c19ee498a 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -1,7 +1,8 @@ """ Fused Attention =============== -This is a Triton implementation of the Flash Attention algorithm (Dao et al., 
https://arxiv.org/pdf/2205.14135v2.pdf) +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) """ import pytest @@ -349,5 +350,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms - -bench_flash_attention.run(save_path='.', print_data=True) +# only works on A100 at the moment +# bench_flash_attention.run(save_path='.', print_data=True) From 86cab58d89909aebd8970ca80a3e85220dd033d9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 18 Jul 2022 14:54:13 -0700 Subject: [PATCH 155/215] [CI] Changed dev wheel date to UTC time to match CRON schedule (#587) --- .github/workflows/wheels.yml | 2 +- python/tutorials/06-fused-attention.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 1d8d450f2..db682f33f 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -18,7 +18,7 @@ jobs: - name: Patch setup.py run: | #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py - export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g') + export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd") sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py echo "" >> python/setup.cfg echo "[build_ext]" >> python/setup.cfg diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index c19ee498a..fb0f4f958 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -1,7 +1,7 @@ """ Fused Attention =============== -This is a Triton implementation of the Flash Attention algorithm +This is a Triton implementation of the Flash Attention algorithm (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) """ From 9b2bc88d11af01b26f0bf7e74631f095e00f36d1 Mon Sep 17 00:00:00 2001 From: daadaada Date: Wed, 20 Jul 2022 12:22:37 +0800 Subject: [PATCH 156/215] [BACKEND] Better bf16 support (#588) --- lib/codegen/selection/generator.cc | 86 ++++++++++---- lib/ir/constant.cc | 2 + python/test/unit/language/test_core.py | 105 +++++++++++++----- .../test/unit/operators/test_cross_entropy.py | 8 +- python/triton/language/semantic.py | 39 +++++-- python/triton/ops/cross_entropy.py | 2 +- 6 files changed, 180 insertions(+), 62 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index b30283ced..31ecfacba 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ // types #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() -#define bf16_ty builder_->getBFloatTy() +#define bf16_ty builder_->getInt16Ty() #define f32_ty builder_->getFloatTy() #define i1_ty builder_->getInt1Ty() #define i8_ty builder_->getInt8Ty() @@ -178,7 +178,7 @@ Type *generator::cvt(ir::type *ty) { case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); - case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_); + case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); // 
use int16 as storage type case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); @@ -378,8 +378,8 @@ void generator::visit_launch_inst(ir::launch_inst *launch) { */ void generator::visit_binary_operator(ir::binary_operator*x) { using ll = llvm::Instruction::BinaryOps; + using tt = ir::binary_op_t; auto cvt = [](ir::binary_op_t op){ - using tt = ir::binary_op_t; switch(op) { case tt::Add: return ll::Add; case tt::FAdd: return ll::FAdd; @@ -406,20 +406,51 @@ void generator::visit_binary_operator(ir::binary_operator*x) { for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; - auto op = cvt(x->get_op()); - if(op == ll::Add) - vals_[x][idx] = add(lhs, rhs); - else if(op == ll::Mul) - vals_[x][idx] = mul(lhs, rhs); - else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && - x->get_type()->get_scalar_ty()->is_fp32_ty()){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), - " div.full.f32 $0, $1, $2;", "=r,r,r", false); - vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); + // manually select bf16 bin op + if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) { + assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty()); + if (x->get_op() == tt::FAdd) { // a + b = a * 1.0 + b + InlineAsm *bf16_add_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + "{ .reg .b16 c; \n\t" + " mov.b16 c, 0x3f80U; \n\t" // 1.0 + " fma.rn.bf16 $0, $1, c, $2; } \n\t", + "=h,h,h", false); + vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs}); + } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a + InlineAsm *bf16_sub_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + " { .reg .b16 c; \n\t" + " mov.b16 c, 0xbf80U; \n\t" // -1.0 + " fma.rn.bf16 $0, $2, c, $1;} \n\t", + "=h,h,h", false); + vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs}); + } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0 + InlineAsm *bf16_mul_asm = + InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), + " { .reg .b16 c; \n\t" + " mov.b16 c, 0x8000U; \n\t" // 0.0 + " fma.rn.bf16 $0, $1, $2, c;} \n\t", + "=h,h,h", false); + vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs}); + } else + throw std::runtime_error("invalid bin op for bf16"); + } else { // not bf16 + auto op = cvt(x->get_op()); + if(op == ll::Add) + vals_[x][idx] = add(lhs, rhs); + else if(op == ll::Mul) + vals_[x][idx] = mul(lhs, rhs); + else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && + x->get_type()->get_scalar_ty()->is_fp32_ty()){ + InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), + " div.full.f32 $0, $1, $2;", "=r,r,r", false); + vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); - } - else - vals_[x][idx] = bin_op(op, lhs, rhs); + } + else + vals_[x][idx] = bin_op(op, lhs, rhs); + } } } @@ -970,8 +1001,6 @@ void generator::visit_store_inst(ir::store_inst * x){ has_l2_evict_policy = false; auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); - if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store - ty = f16_ty; if(ty->isIntegerTy(1)) ty = builder_->getInt8Ty(); for(size_t i = 0; i < idxs.size(); i += vec){ @@ -2830,9 +2859,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ // pointer to temporary shared memory 
Type *ty = cvt(out->get_type()->get_scalar_ty()); - if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store - ty = f16_ty; - // Orders analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); @@ -3229,8 +3255,22 @@ void generator::visit_constant_int(ir::constant_int *x){ void generator::visit_constant_fp(ir::constant_fp *x){ Type *ty = cvt(x->get_type()->get_scalar_ty()); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ConstantFP::get(ty, x->get_value()); + for(indices_t idx: idxs_.at(x)) { + // manually select bf16 constant + if (x->get_type()->get_scalar_ty()->is_bf16_ty()) { + // highest 16 bits of fp32 + float fp32_value = x->get_value(); + uint16_t bf16_raw = (*reinterpret_cast(&fp32_value) + & 0xffff0000) >> 16; + std::stringstream const_str; + const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned + InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false), + " mov.b16 $0, " + const_str.str() + ";", + "=h", false); + vals_[x][idx] = builder_->CreateCall(bf16_const, {}); + } else + vals_[x][idx] = ConstantFP::get(ty, x->get_value()); + } } void generator::visit_alloc_const(ir::alloc_const *alloc) { diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index ab1f6f497..417626c92 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) { return constant_int::get(ty, 0); case type::FP16TyID: return constant_fp::get(type::get_fp16_ty(ctx), 0); + case type::BF16TyID: + return constant_fp::get(type::get_bf16_ty(ctx), 0); case type::FP32TyID: return constant_fp::get(type::get_fp32_ty(ctx), 0); case type::FP64TyID: diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index cb2cb9c33..561ed6af5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -33,27 +33,37 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h shape = (shape, ) if rs is None: rs = RandomState(seed=17) - dtype = getattr(np, dtype_str) if dtype_str in int_dtypes + uint_dtypes: iinfo = np.iinfo(getattr(np, dtype_str)) low = iinfo.min if low is None else max(low, iinfo.min) high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) x = rs.randint(low, high, shape, dtype=dtype) x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. return x elif dtype_str in float_dtypes: - return rs.normal(0, 1, shape).astype(dtype) + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') + & np.uint32(0xffff0000)).view('float32') else: raise RuntimeError(f'Unknown dtype {dtype_str}') -def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]: +def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type becasue the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' t = x.dtype.name if t in uint_dtypes: signed_type_name = t.lstrip('u') # e.g. 
"uint16" -> "int16" x_signed = x.astype(getattr(np, signed_type_name)) return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) else: + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() return torch.tensor(x, device=device) @@ -72,6 +82,8 @@ def to_numpy(x): if isinstance(x, TensorWrapper): return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() return x.cpu().numpy() else: raise ValueError(f"Not a triton-compatible tensor: {x}") @@ -84,19 +96,30 @@ def patch_kernel(template, to_replace): return kernel -@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) +def check_type_supported(dtype): + ''' + skip test if dtype is not supported on the current device + ''' + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + + +@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"]) def test_empty_kernel(dtype_x, device='cuda'): SIZE = 128 @triton.jit def kernel(X, SIZE: tl.constexpr): pass - x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device) + check_type_supported(dtype_x) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) kernel[(1, )](x, SIZE=SIZE, num_warps=4) # generic test functions def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): + check_type_supported(dtype_x) # early return if dtype_x is not supported SIZE = 128 # define the kernel / launch-grid @@ -115,8 +138,8 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): # reference result z_ref = eval(expr if numpy_expr is None else numpy_expr) # triton result - x_tri = to_triton(x, device=device) - z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) @@ -154,6 +177,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): + check_type_supported(dtype_x) # early return if dtype_x is not supported + check_type_supported(dtype_y) SIZE = 128 # define the kernel / launch-grid @@ -180,8 +205,8 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y= if dtype_z is not None: z_ref = z_ref.astype(dtype_z) # triton result - x_tri = to_triton(x, device=device) - y_tri = to_triton(y, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4) np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) @@ -193,15 +218,20 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: # remainders than stock LLVM. We currently don't expect to match it # bit-for-bit. 
return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), ('int32', 'float16'), ('int32', 'float32'), + ('int64', 'bfloat16'), ('int64', 'float16'), ('int64', 'float32'), ('int64', 'float64'), + ('uint16', 'bfloat16'), ('uint16', 'float16'), ('uint16', 'float32'), + ('uint32', 'bfloat16'), ('uint32', 'float16'), ('uint32', 'float32'), + ('uint64', 'bfloat16'), ('uint64', 'float16'), ('uint64', 'float32'), ('uint64', 'float64'), @@ -215,15 +245,15 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) for op in ['+', '-', '*', '/', '%'] - for dtype_x in dtypes - for dtype_y in dtypes + for dtype_x in dtypes + ['bfloat16'] + for dtype_y in dtypes + ['bfloat16'] ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'): # Triton promotes 16-bit floating-point / and % to 32-bit because there # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. @@ -266,8 +296,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'): @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) for op in ['&', '|', '^'] - for dtype_x in dtypes - for dtype_y in dtypes + for dtype_x in dtypes + ['bfloat16'] + for dtype_y in dtypes + ['bfloat16'] ]) def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): expr = f'x {op} y' @@ -337,7 +367,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): # test unary ops # --------------- @pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, ' -x') for dtype_x in dtypes + (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16'] ] + [ (dtype_x, ' ~x') for dtype_x in int_dtypes ]) @@ -732,9 +762,10 @@ def test_f16_to_f8_rounding(): @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in ['min', 'max', 'argmin', 'argmax', 'sum'] - for dtype in dtypes + for dtype in dtypes + ['bfloat16'] for shape in [32, 64, 128, 512]]) def test_reduce1d(op, dtype_str, shape, device='cuda'): + check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested # triton kernel @triton.jit @@ -752,9 +783,18 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'): 'argmin': np.argmin, 'argmax': np.argmax}[op] # numpy result z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str - z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + z_tri_dtype_str = z_dtype_str + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) # triton result - z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device) + z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), + device=device, dst_type=z_tri_dtype_str) kernel[(1,)](x_tri, z_tri, BLOCK=shape) z_tri = to_numpy(z_tri) # compare @@ -770,7 +810,7 @@ def 
test_reduce1d(op, dtype_str, shape, device='cuda'): reduce_configs1 = [ - (op, dtype, (1, 1024), axis) for dtype in dtypes + (op, dtype, (1, 1024), axis) for dtype in dtypes + ['bfloat16'] for op in ['min', 'max', 'argmin', 'argmax', 'sum'] for axis in [1] ] @@ -805,11 +845,19 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str + z_tri_dtype_str = z_dtype_str # numpy result - z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs), - device=device) + device=device, dst_type=z_tri_dtype_str) kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) z_tri = to_numpy(z_tri) # compare @@ -834,10 +882,11 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): @pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) - for dtype in ['float16', 'float32'] + for dtype in ['bfloat16', 'float16', 'float32'] for shape in [(64, 64), (128, 128)] for perm in [(1, 0)]]) def test_permute(dtype_str, shape, perm, device='cuda'): + check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested # triton kernel @triton.jit @@ -852,16 +901,16 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # input x = numpy_random(shape, dtype_str=dtype_str) # triton result - z_tri = to_triton(np.empty_like(x), device=device) - z_tri_contiguous = to_triton(np.empty_like(x), device=device) - x_tri = to_triton(x, device=device) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), BLOCK_M=shape[0], BLOCK_N=shape[1]) pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1]) - # torch result + # numpy result z_ref = x.transpose(*perm) # compare triton.testing.assert_almost_equal(z_tri, z_ref) @@ -1038,8 +1087,10 @@ def test_arange(start, device='cuda'): # Testing masked loads with an intermate copy to shared memory run. 
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device='cuda'): + check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested + M = 32 N = 32 K = 16 diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index 08516257b..e28db4815 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -2,18 +2,22 @@ import pytest import torch import triton +import triton._C.libtriton.triton as _triton @pytest.mark.parametrize("M, N, dtype, mode", [ (M, N, dtype, mode) for M in [1024, 821] for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32'] + for dtype in ['bfloat16', 'float16', 'float32'] for mode in ['forward', 'backward'] ] ) def test_op(M, N, dtype, mode): - dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype] + cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) + if cc < 80 and dtype == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype] # create inputs x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True) idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda') diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index a31fec384..8878a8195 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -58,14 +58,22 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t return tl.float32 # 3 ) if one operand is half, the other is implicitly converted to half # unless we're doing / or %, which do not exist natively in PTX for fp16. 
+ # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp if a_ty.is_fp16() or b_ty.is_fp16(): if div_or_mod: return tl.float32 else: return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 if not a_ty.is_int() or not b_ty.is_int(): assert False - # 4 ) both operands are integer and undergo + # 5 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" @@ -768,16 +776,25 @@ def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, builder: ir.builder) -> tl.tensor: - # TODO: type checking + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, + op: str, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + element_ty) if ptr.type.is_block(): if mask: mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) @@ -798,7 +815,7 @@ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) sca_ty = val.type.scalar # direct call to atomic_max for integers if sca_ty.is_int(): @@ -830,7 +847,7 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) sca_ty = val.type.scalar # direct call to atomic_min for integers if sca_ty.is_int(): @@ -870,7 +887,7 @@ def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) sca_ty = val.type.scalar op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) @@ -880,7 +897,7 @@ def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) @@ -888,7 +905,7 @@ def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 
builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) @@ -896,7 +913,7 @@ def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) @@ -904,7 +921,7 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) # ===----------------------------------------------------------------------===// @@ -978,6 +995,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: input = cast(input, tl.int32, builder) + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is tl.bfloat16: + input = cast(input, tl.float32, builder) + # choose the right unsigned operation if scalar_ty.is_int_unsigned(): int_op_to_unit = { diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 910417d2c..63ce81074 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -65,7 +65,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): # write result in-place in PROBS dout = tl.load(DPROBS + row) din = (probs - delta) * dout - tl.store(PROBS, din.to(tl.float16), mask=cols < N) + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) class _cross_entropy(torch.autograd.Function): From af85f5fa468b5d43d637a1ddf3f69ff984fb7ee0 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 20 Jul 2022 17:34:07 -0700 Subject: [PATCH 157/215] [FRONTEND] Refresh cache when the source code of outlined functions are changed (#590) --- python/src/triton.cc | 8 +++++++- python/test/unit/runtime/test_cache.py | 20 ++++++++++++++++++++ python/tutorials/03-matrix-multiplication.py | 8 ++++---- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index fcebeeb5f..260f83942 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -236,8 +236,14 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f continue; } // argument is `constexpr` - if(py::hasattr(arg, "value")){ + if (py::hasattr(arg, "value")) { py::object value = arg.attr("value"); + // check if value is a callable object using PyCallable_Check + if (PyCallable_Check(value.ptr())) { + throw std::runtime_error( + "constant argument cannot be a callable object: " + + std::string(py::str(arg))); + } py::object name = arg_names[i]; constants[name] = value; py::object repr = py::repr(value); diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index d866d6983..fd95dbd38 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -130,3 +130,23 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', 
cache_str[-1]) spec_type = None if cache_str_match is None else cache_str_match.group(1) assert spec_type == value_type + + +def test_constexpr_not_callable() -> None: + @triton.jit + def kernel(X, c: tl.constexpr): + tl.store(X, 2) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + error = False + try: + kernel[(1, )](x, c="str") + except BaseException: + error = True + assert error is False + # try and catch + try: + kernel[(1, )](x, c=tl.abs) + except BaseException: + error = True + assert error is True diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 39bf8c46a..231d3371c 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -236,8 +236,8 @@ def matmul_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk # you can fuse arbitrary activation functions here # while the accumulator is still in FP32! - if ACTIVATION: - accumulator = ACTIVATION(accumulator) + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) c = accumulator.to(tl.float16) # ----------------------------------------------------------- @@ -261,7 +261,7 @@ def leaky_relu(x): # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel -def matmul(a, b, activation=None): +def matmul(a, b, activation=""): # checks constraints assert a.shape[1] == b.shape[0], "incompatible dimensions" assert a.is_contiguous(), "matrix A must be contiguous" @@ -347,7 +347,7 @@ def benchmark(M, N, K, provider): ) if provider == 'triton + relu': ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b, activation=leaky_relu) + lambda: matmul(a, b, activation="leaky_relu") ) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) From f28caddbf8c30801a331d4f8e36f2106ad07bb90 Mon Sep 17 00:00:00 2001 From: Da Yan Date: Fri, 22 Jul 2022 00:54:27 +0800 Subject: [PATCH 158/215] [FRONTEND] Allow tl.where to select pointers (#595) --- python/src/triton.cc | 1 + python/test/unit/language/test_core.py | 61 +++++++++++++++++++++++--- python/triton/language/semantic.py | 14 ++---- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 260f83942..81e1b66fe 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -878,6 +878,7 @@ void init_triton_ir(py::module &&m) { .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) .def("create_downcast", &ir::builder::create_downcast, ret::reference) .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference) + .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) // phi .def("create_phi", &ir::builder::create_phi, ret::reference) // Binary instructions diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 561ed6af5..93063d064 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -17,6 +17,7 @@ int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] float_dtypes = ['float16', 'float32', 'float64'] dtypes = int_dtypes + uint_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] def _bitwidth(dtype: str) -> int: @@ -46,6 +47,8 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h elif dtype_str == 'bfloat16': return (rs.normal(0, 1, shape).astype('float32').view('uint32') & 
np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 else: raise RuntimeError(f'Unknown dtype {dtype_str}') @@ -245,8 +248,8 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) for op in ['+', '-', '*', '/', '%'] - for dtype_x in dtypes + ['bfloat16'] - for dtype_y in dtypes + ['bfloat16'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' @@ -296,8 +299,8 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'): @pytest.mark.parametrize("dtype_x, dtype_y, op", [ (dtype_x, dtype_y, op) for op in ['&', '|', '^'] - for dtype_x in dtypes + ['bfloat16'] - for dtype_y in dtypes + ['bfloat16'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 ]) def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): expr = f'x {op} y' @@ -363,11 +366,55 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) +# --------------- +# test where +# --------------- +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +def test_where(dtype): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device='cuda') + x_tri = to_triton(x, device='cuda', dst_type=dtype) + y_tri = to_triton(y, device='cuda', dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs) + assert (z == to_numpy(z_tri)).all() + + # --------------- # test unary ops # --------------- @pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, ' -x') for dtype_x in dtypes + ['bfloat16'] + (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16 ] + [ (dtype_x, ' ~x') for dtype_x in int_dtypes ]) @@ -762,7 +809,7 @@ def test_f16_to_f8_rounding(): @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in ['min', 'max', 'argmin', 'argmax', 'sum'] - for dtype in dtypes + ['bfloat16'] + for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) def test_reduce1d(op, dtype_str, shape, device='cuda'): check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested @@ -810,7 +857,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'): reduce_configs1 = [ - (op, dtype, (1, 1024), axis) for dtype in dtypes + 
['bfloat16'] + (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16 for op in ['min', 'max', 'argmin', 'argmax', 'sum'] for axis in [1] ] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8878a8195..ecd740114 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -637,7 +637,7 @@ def cast(input: tl.tensor, if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): bitwidth = dst_sca_ty.int_bitwidth if bitwidth == 64: - return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) if bitwidth == 1: return not_equal(cast(input, tl.int64, builder), @@ -969,16 +969,8 @@ def where(condition: tl.tensor, x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) - # TODO: we need to check x's and y's shape? - x_ty = x.type.scalar - y_ty = y.type.scalar - ty = computation_type_impl(x_ty, y_ty, div_or_mod=False) - x = cast(x, ty, builder) - y = cast(y, ty, builder) - if x.type.is_block(): - ret_ty = tl.block_type(ty, x.type.shape) - else: - ret_ty = ty + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + ret_ty = x.type return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) From ab56d310dde4d5e866b2a664cfa4db77b2967339 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 23 Jul 2022 20:08:03 -0700 Subject: [PATCH 159/215] [BACKEND][IR] Fixed up internal dtype size for booleans (1bit -> 8bit) (#600) --- lib/ir/type.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 43761f482..5667dcac5 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -27,7 +27,7 @@ unsigned type::get_primitive_size_in_bits() const { case BF16TyID: return 16; case FP32TyID: return 32; case FP64TyID: return 64; - case IntegerTyID: return ((integer_type*)(this))->get_bitwidth(); + case IntegerTyID: return std::max(8, ((integer_type*)(this))->get_bitwidth()); case BlockTyID: return ((block_type*)(this))->get_bitwidth(); default: return 0; } From e02e56dc632ffae213ef2371ef7218911efcc15f Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 23 Jul 2022 21:54:12 -0700 Subject: [PATCH 160/215] [FRONTEND] Add missing rfloordiv (#598) * [FRONTEND] Add missing rfloordiv * fix tests --- python/test/unit/language/test_core.py | 14 ++++++++++++++ python/triton/language/core.py | 5 +++++ 2 files changed, 19 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 93063d064..2eadf34a2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1363,6 +1363,20 @@ def test_constexpr_shape(): kernel[(1,)](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + +def test_constexpr_scalar_shape(): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32)) + kernel[(1,)](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + # ------------- # test if # ------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 4197a3333..fdf9063a7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -472,6 +472,11 @@ class tensor: other = _to_tensor(other, _builder) return 
semantic.floordiv(self, other, _builder) + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + @builtin def __mod__(self, other, _builder=None): other = _to_tensor(other, _builder) From 027321cdcf803e45913370a4ab153a6b8f4bcaf9 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 24 Jul 2022 17:47:23 -0700 Subject: [PATCH 161/215] [FRONTEND] Make tl.rand() 1-exclusive (#601) --- python/test/unit/language/test_random.py | 21 +++++++++++++++++++++ python/triton/language/random.py | 5 +++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 042065403..39ae59e35 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -175,3 +175,24 @@ def test_randn(size, seed, device='cuda'): kernel[grid](x, N, seed) assert abs(x.mean()) < 1e-2 assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + +def test_rand_limits(): + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint32_to_uniform_float(x) + tl.store(output + idx, y) + + min_max_int32 = torch.tensor([ + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + ], dtype=torch.int32, device='cuda') + output = torch.empty(2, dtype=torch.float32, device='cuda') + kernel[(1,)](min_max_int32, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/python/triton/language/random.py b/python/triton/language/random.py index c95eac9fc..67de92c43 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -91,9 +91,10 @@ def uint32_to_uniform_float(x): Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). """ x = x.to(tl.int32, bitcast=True) - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + scale = 4.6566127342e-10 x = tl.where(x < 0, -x - 1, x) - return x * max + return x * scale @triton.jit From 923d468187082b2783233f4e7f37901e111804d8 Mon Sep 17 00:00:00 2001 From: Anton Kostin Date: Mon, 25 Jul 2022 19:30:03 +0300 Subject: [PATCH 162/215] Update LICENSE (#602) --- LICENSE | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LICENSE b/LICENSE index a50165bb2..4d18c2bdf 100755 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ /* * Copyright 2018-2020 Philippe Tillet -* Copyright 2020-2021 OpenAI +* Copyright 2020-2022 OpenAI * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files @@ -20,4 +20,4 @@ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-*/ \ No newline at end of file +*/ From 968f59027e74744b2df3d21dc8286594450622ea Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 29 Jul 2022 21:36:08 -0700 Subject: [PATCH 163/215] Expose `module.print` in pybind (#604) --- python/src/triton.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/src/triton.cc b/python/src/triton.cc index 81e1b66fe..940c5bc0a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -758,6 +758,9 @@ void init_triton_ir(py::module &&m) { .def("has_function", &ir::module::has_function) .def("get_function", &ir::module::get_function, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) + .def("print", [](ir::module *self) { + self->print(std::cout); + }) .def("reset_ret_ty", &ir::module::reset_ret_ty) .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { const auto metadatas = self->get_metadatas(); From 7b91c7befd5df91603afe3d61b6d3823e36d3f7a Mon Sep 17 00:00:00 2001 From: Daniil Fukalov <1671137+dfukalov@users.noreply.github.com> Date: Wed, 3 Aug 2022 02:12:48 +0300 Subject: [PATCH 164/215] Fix "warning: control reaches end of non-void function". (#607) --- python/src/triton.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 940c5bc0a..f72513395 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -558,16 +558,16 @@ void init_triton_codegen(py::module &&m) { } if(backend == CUDA) return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); - if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); + assert(backend == ROCM); + return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); }, py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ py::gil_scoped_release allow_threads; if(backend == CUDA) return cu_load_binary(name, asm_map, n_shared_bytes, dev); - if(backend == ROCM) - return hip_load_binary(name, asm_map, n_shared_bytes, dev); + assert(backend == ROCM); + return hip_load_binary(name, asm_map, n_shared_bytes, dev); }, py::return_value_policy::take_ownership); } From cc7937622293aa4120207fd3a552fcd2af159b79 Mon Sep 17 00:00:00 2001 From: Daniil Fukalov <1671137+dfukalov@users.noreply.github.com> Date: Mon, 8 Aug 2022 03:10:18 +0300 Subject: [PATCH 165/215] Fix deprectaion warning on CreateGEP(Value *, ArrayRef, const Twine &) (#608) This variant of CreateGEP() is already removed in LLVM 14. 
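For context, the replacement overload requires the GEP element type to be spelled out explicitly. A minimal sketch of the migration (illustrative only, not part of the patch; it assumes an llvm::IRBuilder<> named `b` and recovers the pointee type the same way the hunks below do):

    #include "llvm/IR/IRBuilder.h"

    // Build `ptr + off` as a GEP without the deprecated
    // CreateGEP(Value*, ...) overload that infers the element type.
    llvm::Value *make_gep(llvm::IRBuilder<> &b, llvm::Value *ptr, llvm::Value *off) {
      // LLVM <= 13 allowed: b.CreateGEP(ptr, off);
      // LLVM 14 requires the element type as the first argument:
      llvm::Type *elt_ty = ptr->getType()->getScalarType()->getPointerElementType();
      return b.CreateGEP(elt_ty, ptr, off, "gep");
    }
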
--- lib/codegen/selection/generator.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 31ecfacba..d6d059859 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -65,18 +65,21 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ if(auto* gep = dyn_cast(ptr)) if(ConstantInt* cst1 = dyn_cast(gep->idx_begin())) if(ConstantInt* cst2 = dyn_cast(off)){ - return (*builder_)->CreateGEP(gep->getPointerOperand(), - (*builder_)->CreateAdd(cst1, cst2)); + return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(), + gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2)); } // ptr + (off + cst) -> (ptr + off) + cst if(auto* bin = dyn_cast(off)) if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) if(ConstantInt* cst = dyn_cast(bin->getOperand(1))){ - return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)), - bin->getOperand(1)); + Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), + ptr, bin->getOperand(0)); + return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(), + gep, bin->getOperand(1)); } // default - return (*builder_)->CreateGEP(ptr, off, name); + return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), + ptr, off, name); } //Value* geper::operator()(Type *ty, Value *ptr, std::vector vals, const std::string &name) { From 3e2953f3570acc807f975eee58f1565483841b72 Mon Sep 17 00:00:00 2001 From: Da Yan Date: Wed, 10 Aug 2022 09:59:32 -0700 Subject: [PATCH 166/215] Allow multiple_of and max_contiguous to accept n-d values (#617) --- include/triton/ir/instructions.h | 6 +++--- include/triton/ir/metadata.h | 8 +++++--- include/triton/ir/module.h | 2 +- lib/codegen/analysis/align.cc | 12 ++++++------ lib/ir/metadata.cc | 4 ++-- python/src/triton.cc | 4 ++-- python/triton/language/core.py | 26 ++++++++++++++++++++------ python/triton/language/semantic.py | 12 ++++++++---- 8 files changed, 47 insertions(+), 27 deletions(-) diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 1bad86c33..8a1c3f7cf 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -59,8 +59,8 @@ public: std::string repr() const { return repr_impl(); } // metadata void set_metadata(ir::metadata::kind_t kind, - unsigned value) { metadatas_[kind] = value;} - unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} + std::vector value) { metadatas_[kind] = value;} + std::vector get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} // cloning ir::instruction* clone() { ir::instruction* res = clone_impl(); @@ -77,7 +77,7 @@ public: private: basic_block *parent_; - std::map metadatas_; + std::map> metadatas_; value_id_t id_; }; diff --git a/include/triton/ir/metadata.h b/include/triton/ir/metadata.h index 9d4fb1137..69512c6b0 100644 --- a/include/triton/ir/metadata.h +++ b/include/triton/ir/metadata.h @@ -3,6 +3,8 @@ #ifndef _TRITON_IR_METADATA_H_ #define _TRITON_IR_METADATA_H_ +#include + namespace triton{ namespace ir{ @@ -16,14 +18,14 @@ public: }; private: - metadata(kind_t kind, unsigned value); + metadata(kind_t kind, std::vector value); public: - static metadata* get(kind_t kind, unsigned value); + static metadata* get(kind_t kind, std::vector value); private: kind_t kind_; - 
unsigned value_; + std::vector value_; }; } diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index aa279af98..1ed0b6646 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -70,7 +70,7 @@ private: class module { typedef std::pair val_key_t; - typedef std::pair md_pair_t; + typedef std::pair> md_pair_t; friend class function; public: diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 1c48a4c05..6bd6e4ef9 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -366,9 +366,9 @@ std::vector align::populate_max_contiguous(ir::value *v){ if(max_contiguous_.find(v) != max_contiguous_.end()) return max_contiguous_.at(v); if(auto *x = dynamic_cast(v)){ - unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous); - if(max_contiguous > 0) - return add_to_cache(x, {max_contiguous}, max_contiguous_); + std::vector max_contiguous = x->get_metadata(ir::metadata::max_contiguous); + if(!max_contiguous.empty()) + return add_to_cache(x, max_contiguous, max_contiguous_); } if(auto *x = dynamic_cast(v)) return populate_max_contiguous_cast(x); @@ -521,9 +521,9 @@ std::vector align::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); if(auto *x = dynamic_cast(v)){ - unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); - if(multiple_of > 0) - return add_to_cache(x, {multiple_of}, starting_multiple_); + std::vector multiple_of = x->get_metadata(ir::metadata::multiple_of); + if(!multiple_of.empty()) + return add_to_cache(x, multiple_of, starting_multiple_); } if(auto *x = dynamic_cast(v)) return populate_starting_multiple_cast(x); diff --git a/lib/ir/metadata.cc b/lib/ir/metadata.cc index 16bc059c5..9d31963c2 100644 --- a/lib/ir/metadata.cc +++ b/lib/ir/metadata.cc @@ -3,10 +3,10 @@ namespace triton{ namespace ir{ -metadata::metadata(kind_t kind, unsigned value) +metadata::metadata(kind_t kind, std::vector value) : kind_(kind), value_(value) { } -metadata* metadata::get(kind_t kind, unsigned value) { +metadata* metadata::get(kind_t kind, std::vector value) { return new metadata(kind, value); } diff --git a/python/src/triton.cc b/python/src/triton.cc index f72513395..a606c051a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -625,13 +625,13 @@ void init_triton_ir(py::module &&m) { .def(py::init<>()); py::class_(m, "value") - .def("multiple_of", [](ir::value *self, int val) { + .def("multiple_of", [](ir::value *self, std::vector val) { if (auto *instr = dynamic_cast(self)) { instr->set_metadata(ir::metadata::multiple_of, val); } else throw std::runtime_error("multiple_of"); }) - .def("max_contiguous", [](ir::value *self, int val) { + .def("max_contiguous", [](ir::value *self, std::vector val) { if (auto *instr = dynamic_cast(self)) { instr->set_metadata(ir::metadata::max_contiguous, val); } else diff --git a/python/triton/language/core.py b/python/triton/language/core.py index fdf9063a7..29a128321 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1088,21 +1088,35 @@ def debug_barrier(_builder=None): @builtin -def multiple_of(input, value, _builder=None): +def multiple_of(input, values, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. 
""" - value = _constexpr_to_value(value) - return semantic.multiple_of(input, value) + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) @builtin -def max_contiguous(input, value, _builder=None): +def max_contiguous(input, values, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - value = _constexpr_to_value(value) - return semantic.max_contiguous(input, value) + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index ecd740114..a860dc763 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1090,13 +1090,17 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: ## -def multiple_of(x: tl.tensor, value: int) -> tl.tensor: - x.handle.multiple_of(value) +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.multiple_of(values) return x -def max_contiguous(x: tl.tensor, value: int) -> tl.tensor: - x.handle.max_contiguous(value) +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.max_contiguous(values) return x From 7394d732adcec5a67034926a28080af2a7853217 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 15 Aug 2022 19:16:49 -0700 Subject: [PATCH 167/215] [DOCS] support for variable head dimensions in flash attention triton tutorial (#623) --- python/tutorials/06-fused-attention.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index fb0f4f958..035514746 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -204,13 +204,16 @@ class _attention(torch.autograd.Function): def forward(ctx, q, k, v, sm_scale): BLOCK = 128 # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( q, k, v, sm_scale, tmp, L, m, @@ -221,14 +224,14 @@ class _attention(torch.autograd.Function): 
o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], q.shape[2], BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=64, num_warps=4, + BLOCK_DMODEL=Lk, num_warps=num_warps, num_stages=1, ) ctx.save_for_backward(q, k, v, o, L, m) ctx.BLOCK = BLOCK ctx.grid = grid ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = 64 + ctx.BLOCK_DMODEL = Lk return o @staticmethod @@ -245,6 +248,8 @@ class _attention(torch.autograd.Function): do_scaled, delta, BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) + + num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8 _bwd_kernel[(ctx.grid[1],)]( q, k, v, ctx.sm_scale, o, do_scaled, @@ -257,7 +262,7 @@ class _attention(torch.autograd.Function): q.shape[0], q.shape[1], q.shape[2], ctx.grid[0], BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, num_stages=1, ) return dq, dk, dv, None From fe0c29b9ece6e9f88f4d4ed9612c3d591416f821 Mon Sep 17 00:00:00 2001 From: Daniil Fukalov <1671137+dfukalov@users.noreply.github.com> Date: Sat, 27 Aug 2022 02:20:21 +0300 Subject: [PATCH 168/215] Fix inconsistent struct declaration instead of class. (#632) Looks like typo. --- include/triton/codegen/analysis/layout.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index a69687875..313c7b1b3 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -211,7 +211,8 @@ private: TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32; }; -struct scanline_layout: public distributed_layout { +class scanline_layout: public distributed_layout { +public: scanline_layout(size_t num_warps, const std::vector& axes, const std::vector& shape, @@ -225,7 +226,7 @@ struct scanline_layout: public distributed_layout { int contig_per_thread(size_t k) { return nts_.at(k); } int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);} -public: +private: // micro tile size. The size of a tile held by a thread block. std::vector mts_; // nano tile size. The size of a tile held by a thread. From 210a296699947ba632f791ab31f8fdb2cfe1c527 Mon Sep 17 00:00:00 2001 From: Da Yan Date: Sat, 27 Aug 2022 11:40:55 +0800 Subject: [PATCH 169/215] [BACKEND] bf16 flash-attention (#636) --- lib/codegen/selection/generator.cc | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d6d059859..d4829e680 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2125,15 +2125,10 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = fp16x2_ty; } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { - // FIXME: We should use bf16 here. 
- mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - smem_ptr_ty = ptr_ty(f16_ty, 3); - ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = fp16x2_ty; - // mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - // smem_ptr_ty = ptr_ty(bf16_ty, 3); - // ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - // phi_ty = bf16x2_ty; + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(bf16_ty, 3); + ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = bf16x2_ty; } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(fp32_ty, 3); @@ -2233,10 +2228,10 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // std::cout << idxs_[A].size() << std::endl; // std::cout << (m+1)*ldm + k*2 + 3 << std::endl; // int ldm = num_rep_k*4; - Value* ha0 = UndefValue::get(fp16x2_ty); - Value* ha1 = UndefValue::get(fp16x2_ty); - Value* ha2 = UndefValue::get(fp16x2_ty); - Value* ha3 = UndefValue::get(fp16x2_ty); + Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2 + Value* ha1 = UndefValue::get(phi_ty); + Value* ha2 = UndefValue::get(phi_ty); + Value* ha3 = UndefValue::get(phi_ty); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0)); From 437ced38c2b04d59ce4cee5bf0b9a793b365d29d Mon Sep 17 00:00:00 2001 From: Da Yan Date: Wed, 31 Aug 2022 05:20:12 +0800 Subject: [PATCH 170/215] fp8 <> bf16 conversion (#637) Co-authored-by: Philippe Tillet --- include/triton/codegen/selection/generator.h | 2 + lib/codegen/selection/generator.cc | 117 ++++++++++++++++++- lib/codegen/transform/membar.cc | 3 + python/test/unit/language/test_core.py | 14 ++- python/triton/language/semantic.py | 7 ++ 5 files changed, 133 insertions(+), 10 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 7867c356b..bbf0417f1 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -148,6 +148,8 @@ private: std::tuple fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); std::tuple fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3); std::tuple fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); + std::tuple fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3); + std::tuple bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); Value* bf16_to_fp32(Value *in0); Value* fp32_to_bf16(Value *in0); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d4829e680..1fe8fed89 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -569,10 +569,10 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 "prmt.b32 a1, 0, $2, 
0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) - "shr.b32 b0, b0, 1; \n\t" // b0 <<= 1 (shift into fp16 poistion) - "shr.b32 b1, b1, 1; \n\t" // b1 <<= 1 (shift into fp16 position) - "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 | a0) (restore sign) - "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 | a1) (restore sign) + "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 poistion) + "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position) + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign) + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign) "}", "=r,=r,r", false); Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); packed_in = insert_elt(packed_in, in0, (uint64_t)0); @@ -635,6 +635,110 @@ std::tuple generator::fp16x4_to_fp8x4(Value *in0 return std::make_tuple(ret0, ret1, ret2, ret3); } +std::tuple generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) { + // current exp offset: 15 + // Add 112 (127-15) to compensate the difference in exponent bias + // bf16 = (nosign >> (8-4) + 112 << 7) | sign; + // bf16 = (nosign >> 4 + 0x3800) | sign; + Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)}); + InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), + "{" + ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t" + "prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0 + "prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0 + "and.b32 sign0, a0, 0x80008000; \n\t" + "and.b32 sign1, a1, 0x80008000; \n\t" + "and.b32 nosign0, a0, 0x7fff7fff; \n\t" + "and.b32 nosign1, a1, 0x7fff7fff; \n\t" + "shr.b32 nosign0, nosign0, 4; \n\t" + "shr.b32 nosign1, nosign1, 4; \n\t" + "add.u32 nosign0, nosign0, 0x38003800; \n\t" + "add.u32 nosign1, nosign1, 0x38003800; \n\t" + "or.b32 $0, sign0, nosign0; \n\t" + "or.b32 $1, sign1, nosign1; \n\t" + "}", "=r,=r,r", false); + Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); + packed_in = insert_elt(packed_in, in0, (uint64_t)0); + packed_in = insert_elt(packed_in, in1, (uint64_t)1); + packed_in = insert_elt(packed_in, in2, (uint64_t)2); + packed_in = insert_elt(packed_in, in3, (uint64_t)3); + Value *in = bit_cast(packed_in, i32_ty); + Value *ret = call(ptx, {in}); + Value *packed_ret0 = extract_val(ret, {0}); + Value *packed_ret1 = extract_val(ret, {1}); + Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); + Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); + Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); + Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); + return std::make_tuple(ret0, ret1, ret2, ret3); +} + +std::tuple generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { + /* Assuming fp8 exponent offset is 16. bf16 exponent offset is 127. 
+ Max value in fp8: 0b01111111 (0x7f), + bf16: 3ff0 + Min value in fp8: 0b00000000 (0x00) + bf16: 0x3c00 + // @note: +0x8 is for "rounding to nearest zero" + fp8 = (nosign(bf16) - (112 << 7) + 0x8) << 4; + return fp8 | sign; // also permute bytes + */ + InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), + "{\n\t" + ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t" + ".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t" + "mov.u32 fp8_min, 0x38003800; \n\t" + "mov.u32 fp8_max, 0x3ff03ff0; \n\t" + "mov.u32 rn_, 0x80008; \n\t" + "mov.u32 zero, 0; \n\t" + "and.b32 sign0, $1, 0x80008000; \n\t" + "and.b32 sign1, $2, 0x80008000; \n\t" + "prmt.b32 sign, sign0, sign1, 0x7531; \n\t" + "and.b32 nosign0, $1, 0x7fff7fff; \n\t" + "and.b32 nosign1, $2, 0x7fff7fff; \n\t" + + ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max) + "and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t" + "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t" + "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t" + "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t" + "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t" + "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t" + "or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t" + "and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t" + "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t" + "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t" + "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t" + "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t" + "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t" + "or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t" + + "add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero + "add.u32 nosign1, nosign1, rn_; \n\t" + "sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset + "sub.u32 nosign1, nosign1, 0x38003800; \n\t" + "shr.u32 nosign0, nosign0, 4; \n\t" + "shr.u32 nosign1, nosign1, 4; \n\t" + "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t" + "or.b32 $0, nosign, sign; \n\t" + "" + "}", "=r,r,r", false); + Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2)); + Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2)); + packed_in0 = insert_elt(packed_in0, in0, (int)0); + packed_in0 = insert_elt(packed_in0, in1, (int)1); + packed_in1 = insert_elt(packed_in1, in2, (int)0); + packed_in1 = insert_elt(packed_in1, in3, (int)1); + Value *in_arg0 = bit_cast(packed_in0, i32_ty); + Value *in_arg1 = bit_cast(packed_in1, i32_ty); + Value *ret = call(ptx, {in_arg0, in_arg1}); + Value *ret0 = extract_elt(ret, (int)0); + Value *ret1 = extract_elt(ret, (int)1); + Value *ret2 = extract_elt(ret, (int)2); + Value *ret3 = extract_elt(ret, (int)3); + return std::make_tuple(ret0, ret1, ret2, ret3); +} + Value* generator::bf16_to_fp32(Value *in0){ if (tgt_->as_nvidia()->sm() >= 80) { InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), @@ -685,6 +789,11 @@ void generator::visit_cast_inst(ir::cast_inst* x) { return fp8x4_to_fp16x4(a, b, c, d); if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty()) return fp8x4_to_fp32x4(a, b, c, d); + // fp8 <> bf16 + if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty()) + return fp8x4_to_bf16x4(a, b, c, d); + if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty()) + return bf16x4_to_fp8x4(a, b, c, d); throw std::runtime_error("unsupported conversion"); }; for(size_t i = 0; i < x_idxs.size(); i+=4){ diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 22fe00fe6..224f44e9b 100644 --- a/lib/codegen/transform/membar.cc +++ 
b/lib/codegen/transform/membar.cc @@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector &async_write) { else{ if(layouts_->has_tmp(v)) return async_write.size() - 1; + // // Ignore copy_to_shared. It won't modify async behavior. + // if(dynamic_cast(v)) + // return 0; auto it = std::find(async_write.begin(), async_write.end(), v); return std::distance(async_write.begin(), it); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2eadf34a2..13a3d6314 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -719,8 +719,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): assert to_numpy(z_tri) == z_ref -def test_f8_f16_roundtrip(): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_f8_xf16_roundtrip(dtype): """Tests that converting an f8 to f16 and back to f8 doesn't change its value""" + check_type_supported(dtype) + @triton.jit def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -732,13 +735,13 @@ def test_f8_f16_roundtrip(): f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda') f8 = triton.reinterpret(f8_tensor, tl.float8) n_elements = f8_tensor.numel() - f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + xf16 = torch.empty_like(f8_tensor, dtype=dtype) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024) + copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024) - f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) + f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) + copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -746,7 +749,6 @@ def test_f8_f16_roundtrip(): def test_f16_to_f8_rounding(): """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute error is the minimum over all float8. 
- Or the same explanation a bit mathier: for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|""" @triton.jit diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index a860dc763..851d7ba6a 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -581,6 +581,13 @@ def cast(input: tl.tensor, return input src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar + # fp8 <=> bf16/fp16 + if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8(): + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), + dst_ty) + if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()): + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), + dst_ty) # bf16 <=> (not fp32) if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): From 59a8e25f438c0836dbae7f4bd5bc0a746bf22560 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Wed, 14 Sep 2022 12:17:05 -0700 Subject: [PATCH 171/215] [DOCS] Fix typo (#650) --- python/tutorials/03-matrix-multiplication.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 231d3371c..7b2a35bd2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -76,8 +76,8 @@ You will specifically learn about: # # .. code-block:: python # -# pa += BLOCK_SIZE_K * stride_ak; -# pb += BLOCK_SIZE_K * stride_bk; +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; # # # L2 Cache Optimizations From cfbbc7b43a685026ef35285ab2a07833462e2e01 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 14 Sep 2022 13:47:50 -0700 Subject: [PATCH 172/215] [CI] Added V100 tag to disambiguate self-hosted runners (#653) --- .github/workflows/integration-tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f78091568..8d9ec237f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -5,14 +5,13 @@ on: pull_request: branches: - master - - v2.0 jobs: Integration-Tests: - runs-on: self-hosted + runs-on: [self-hosted, V100] steps: From 4580a047106fb524215d203052666524be40762e Mon Sep 17 00:00:00 2001 From: Sophia Wisdom Date: Wed, 14 Sep 2022 14:26:42 -0700 Subject: [PATCH 173/215] [FRONTEND] Improve error message for CPU tensors (#654) Redo of #651 against master. Fixes #525 by catching CUDA error when we check pytorch tensor size and rethrowing a more informative error that says why we failed. --- python/src/triton.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index a606c051a..d56ff8430 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -226,11 +226,16 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f // copy param std::memcpy(params_ptr, &value, 8); params_ptr += 8; - // udpate cache key + // update cache key cache_key += dtype_cache_key_part(arg.attr("dtype")); cache_key += "*"; cache_key += "[multipleof("; - size_t range_size = get_pointer_range_size(value); + size_t range_size; + try { + range_size = get_pointer_range_size(value); + } catch (...) 
{ + throw std::runtime_error("argument tensor #" + std::to_string(i) + " is not on cuda! " + std::string(py::str(arg))); + } cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size))); cache_key += ")]"; continue; From c668d6596e45b7d5134f835c257b0ea627c2be0f Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Fri, 16 Sep 2022 12:26:40 -0700 Subject: [PATCH 174/215] [DOCS] Fix spelling (#664) This PR applies minor spelling fix in comments and string literals to `master`. It shouldn't hurt anything. --- cmake/FindLLVM.cmake | 4 ++-- docs/programming-guide/chapter-2/related-work.rst | 2 +- include/triton/codegen/analysis/layout.h | 2 +- lib/codegen/selection/generator.cc | 2 +- lib/driver/llvm.cc | 2 +- lib/ir/print.cc | 2 +- python/src/functions.h | 2 +- python/test/regression/test_performance.py | 2 +- python/test/unit/language/test_core.py | 4 ++-- python/triton/code_gen.py | 2 +- python/triton/language/core.py | 4 ++-- python/triton/ops/blocksparse/matmul.py | 2 +- python/triton/testing.py | 2 +- python/triton/tools/disasm.py | 2 +- python/tutorials/02-fused-softmax.py | 2 +- python/tutorials/03-matrix-multiplication.py | 2 +- 16 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cmake/FindLLVM.cmake b/cmake/FindLLVM.cmake index b615936e6..f9216a24e 100644 --- a/cmake/FindLLVM.cmake +++ b/cmake/FindLLVM.cmake @@ -25,7 +25,7 @@ # LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn). # LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0). # -# Note: The variable names were chosen in conformance with the offical CMake +# Note: The variable names were chosen in conformance with the official CMake # guidelines, see ${CMAKE_ROOT}/Modules/readme.txt. # Try suffixed versions to pick up the newest LLVM install available on Debian @@ -196,4 +196,4 @@ include(FindPackageHandleStandardArgs) find_package_handle_standard_args(LLVM REQUIRED_VARS LLVM_ROOT_DIR - VERSION_VAR LLVM_VERSION_STRING) \ No newline at end of file + VERSION_VAR LLVM_VERSION_STRING) diff --git a/docs/programming-guide/chapter-2/related-work.rst b/docs/programming-guide/chapter-2/related-work.rst index bb83d4851..e21ec4de7 100644 --- a/docs/programming-guide/chapter-2/related-work.rst +++ b/docs/programming-guide/chapter-2/related-work.rst @@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL Program Representation +++++++++++++++++++++++ -Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming. +Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming. .. 
table:: :widths: 50 50 diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 313c7b1b3..39d40511e 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -246,7 +246,7 @@ struct N_buffer_info_t { std::map firsts_idx; }; -// abstract for dot and coresponding smem values +// abstract for dot and corresponding smem values class shared_layout: public data_layout { private: static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 1fe8fed89..526c64b47 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -569,7 +569,7 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) - "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 poistion) + "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position) "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position) "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign) "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index c4a13b806..e17c381cb 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -96,7 +96,7 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s std::string path_to_ptxas(int& version) { std::vector rets; std::string ret; - // search pathes for ptxas + // search paths for ptxas std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); if(!triton_ptxas.empty()) diff --git a/lib/ir/print.cc b/lib/ir/print.cc index db73ec7d9..4b6e3266f 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -92,7 +92,7 @@ public: //------------------------- void SlotTracker::process_module() { // Nothing to do at the moment. - // Create slots for global variable & unamed functions & ... + // Create slots for global variable & unnamed functions & ... module_processed = true; } diff --git a/python/src/functions.h b/python/src/functions.h index d5b6c15ef..e27941a0b 100644 --- a/python/src/functions.h +++ b/python/src/functions.h @@ -253,7 +253,7 @@ ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) { std::string where_docstr = R"pbdoc( Returns a block of elements from either `x` or `y`, depending on `condition`. Note that `x` and `y` are always evaluated regardless of the value of `condition`. - If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. + If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. :param condition: When True (nonzero), yield x, otherwise yield y. 
:type condition: Block of triton.bool diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index f30b203bb..16811eaa9 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -152,7 +152,7 @@ def test_elementwise(N): cur_mem_clock = nvsmi(['clocks.current.memory'])[0] ref_mem_clock = mem_clocks[DEVICE_NAME] max_gpu_perf = get_dram_gbps() - assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' + assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz' z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) y = torch.randn_like(z) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 13a3d6314..46ddfe760 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -55,7 +55,7 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]: ''' - Note: We need dst_type becasue the type of x can be different from dst_type. + Note: We need dst_type because the type of x can be different from dst_type. For example: x is of type `float32`, dst_type is `bfloat16`. If dst_type is None, we infer dst_type from x. ''' @@ -424,7 +424,7 @@ def test_unary_op(dtype_x, expr, device='cuda'): # ---------------- # test math ops # ---------------- -# @pytest.mark.paramterize("expr", [ +# @pytest.mark.parametrize("expr", [ # 'exp', 'log', 'cos', 'sin' # ]) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 3951d8b6b..e2956aea9 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -57,7 +57,7 @@ def mangle_ty(ty): elt = mangle_ty(ty.scalar) shape = '_'.join(map(str, ty.shape)) return f'{elt}S{shape}S' - assert False, "Unsupport type" + assert False, "Unsupported type" def mangle_fn(name, arg_tys, constants): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 29a128321..e52a488b2 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -192,7 +192,7 @@ class dtype: return builder.get_float_ty() elif self.name == 'fp64': return builder.get_double_ty() - raise ValueError(f'fail to covert {self} to ir type') + raise ValueError(f'fail to convert {self} to ir type') def __str__(self): return self.name @@ -925,7 +925,7 @@ def where(condition, x, y, _builder=None): Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. - If you want to avoid unintented memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. :code:`x` and :code:`y` must have the data type. 
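As the `tl.where` docstring touched just above notes, `tl.where` selects between values that are both evaluated, whereas the `mask` arguments of `tl.load`/`tl.store` guard the memory access itself. A minimal sketch of the two used together (hypothetical kernel, not taken from any patch in this series):

import torch
import triton
import triton.language as tl

@triton.jit
def relu_kernel(X, Y, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements               # guards the memory accesses below
    x = tl.load(X + offsets, mask=mask, other=0.0)
    y = tl.where(x > 0, x, 0.0)               # pure value selection; no memory is touched here
    tl.store(Y + offsets, y, mask=mask)

x = torch.randn(1000, device='cuda')
y = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
relu_kernel[grid](x, y, x.numel(), BLOCK_SIZE=256)

Out-of-bounds lanes are handled entirely by the masked load/store; `tl.where` is only doing arithmetic selection on registers.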
diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 0fa1a5878..4b6d98aac 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -328,7 +328,7 @@ def dsd_lut(layout, block, step, trans, device): # create increments incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() # pad by a factor 2*MAX_NUM_STAGES - # to accomodate pre-fetching inside the kernel + # to accommodate pre-fetching inside the kernel pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) incs = torch.cat((incs, pad)) # create lut diff --git a/python/triton/testing.py b/python/triton/testing.py index bfcd6ef6b..594edcbf2 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -379,7 +379,7 @@ def cuda_memcheck(**target_kwargs): test_id = kwargs['request'].node.callspec.id cmd = f"{path}::{test_fn.__name__}[{test_id}]" out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) - assert out.returncode == 0, "cuda-memcheck returned an error: bounds checkng failed" + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" assert "ERROR SUMMARY: 0 errors" in str(out.stdout) else: test_fn(*args, **kwargs) diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 3672d4b05..24a0787c5 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -104,7 +104,7 @@ def extract(file_path, fun): # peek the next line line = sass_lines[line_idx].decode() # Print sass - # label naming convension: LBB#i + # label naming convention: LBB#i for idx, (ctrl, asm) in enumerate(asm_buffer): # Print label if this is BRA target offset = idx * 16 diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 7af24e18d..7447b60af 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -78,7 +78,7 @@ def softmax_kernel( input_ptrs = row_start_ptr + col_offsets # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) - # Substract maximum for numerical stability + # Subtract maximum for numerical stability row_minus_max = row - tl.max(row, axis=0) # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA) numerator = tl.exp(row_minus_max) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 7b2a35bd2..49382aecd 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -18,7 +18,7 @@ You will specifically learn about: # They are notoriously hard to optimize, hence their implementation is generally done by # hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). # Unfortunately, these libraries are often proprietary and cannot be easily customized -# to accomodate the needs of modern deep learning workloads (e.g., fused activation functions). +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). # In this tutorial, you will learn how to implement efficient matrix multiplications by # yourself with Triton, in a way that is easy to customize and extend. 
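The matrix-multiplication tutorial edited above motivates writing matmuls in Triton precisely because epilogues such as activation functions can be fused into the kernel. A condensed sketch of that idea, with masks omitted and shapes assumed to be exact multiples of the block sizes (illustrative only, not the tutorial's actual code):

import torch
import triton
import triton.language as tl

@triton.jit
def matmul_leaky_relu_kernel(A, B, C, M, N, K,
                             stride_am, stride_ak, stride_bk, stride_bn,
                             stride_cm, stride_cn,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    a_ptrs = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        # advance along K, as in the tutorial comment fixed earlier in this series
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # fused epilogue: leaky ReLU applied while the accumulator is still in registers
    acc = tl.where(acc >= 0, acc, 0.01 * acc)
    c_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc)

M, N, K = 256, 256, 256
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
c = torch.empty((M, N), device='cuda', dtype=torch.float32)
grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
matmul_leaky_relu_kernel[grid](a, b, c, M, N, K,
                               a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                               c.stride(0), c.stride(1),
                               BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)

Because the activation runs on the accumulator before it is ever written back, no extra kernel launch or round trip through global memory is needed; this is the kind of customization the tutorial text refers to.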
# From 889d9e34a114b1fe2e8871d21e713794344d12d3 Mon Sep 17 00:00:00 2001 From: Ian Bearman Date: Sat, 17 Sep 2022 14:25:28 -0700 Subject: [PATCH 175/215] [REPO] update gitignore (#666) Update `.gitignore` to include `.vs` and `.vscode` --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index b32df68cc..95d59cfed 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,6 @@ python/build/ python/triton.egg-info/ python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so + +.vscode +.vs \ No newline at end of file From 4a77dfb042ad6c4d61376439962220d86ee6681a Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 18 Sep 2022 08:51:48 -0700 Subject: [PATCH 176/215] [FRONTEND] Complete rewrite of the runtime (#644) This PR completely rewrites the runtime of Triton to be more lean and clearly separate the compilation step from the just-in-time caching logic. This should substantially reduce launch overhead. --- include/triton/codegen/extern_lib.h | 1 + include/triton/ir/module.h | 1 - lib/codegen/pass.cc | 2 +- python/src/triton.cc | 16 +- python/test/unit/language/test_core.py | 18 +- python/test/unit/operators/test_matmul.py | 7 +- python/test/unit/runtime/test_cache.py | 18 +- python/triton/__init__.py | 7 +- python/triton/{code_gen.py => compiler.py} | 1221 ++++++++------------ python/triton/language/core.py | 2 +- python/triton/ops/blocksparse/softmax.py | 8 +- python/triton/ops/matmul.py | 6 +- python/triton/runtime/__init__.py | 2 + python/triton/runtime/autotuner.py | 204 ++++ python/triton/runtime/jit.py | 415 +++++++ python/triton/testing.py | 2 +- python/triton/utils.py | 48 + 17 files changed, 1198 insertions(+), 780 deletions(-) rename python/triton/{code_gen.py => compiler.py} (50%) create mode 100644 python/triton/runtime/__init__.py create mode 100644 python/triton/runtime/autotuner.py create mode 100644 python/triton/runtime/jit.py create mode 100644 python/triton/utils.py diff --git a/include/triton/codegen/extern_lib.h b/include/triton/codegen/extern_lib.h index c161ff142..02e991407 100644 --- a/include/triton/codegen/extern_lib.h +++ b/include/triton/codegen/extern_lib.h @@ -3,6 +3,7 @@ #include #include +#include #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 1ed0b6646..d09a51a22 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -87,7 +87,6 @@ public: // Functions const functions_list_t &get_function_list() const { return functions_; } - functions_list_t &get_function_list() { return functions_; } function *get_function(const std::string& name) { if(symbols_.find(name) == symbols_.end()) throw std::runtime_error("function " + name + " is not declared"); diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 645f10978..024a838d9 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -106,11 +106,11 @@ std::unique_ptr add_passes_to_emit_bin( // run passes inliner.run(ir); dce.run(ir); - // ir.print(std::cout); peephole.run(ir); dce.run(ir); pipeline.run(ir); dce.run(ir); + // ir.print(std::cout); disassociate.run(ir); dce.run(ir); align.run(ir); diff --git a/python/src/triton.cc b/python/src/triton.cc index d56ff8430..8bfb076c3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -574,6 +574,19 @@ void init_triton_codegen(py::module &&m) { assert(backend == ROCM); return hip_load_binary(name, asm_map, n_shared_bytes, dev); }, py::return_value_policy::take_ownership); + + + struct InstanceDescriptor + 
{ + std::unordered_set divisibleBy16; + std::unordered_set equalTo1; + }; + + py::class_(m, "instance_descriptor") + .def(py::init<>()) + .def(py::init, std::unordered_set>()) + .def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16) + .def_readonly("equal_to_1", &InstanceDescriptor::equalTo1); } @@ -758,10 +771,11 @@ void init_triton_ir(py::module &&m) { .def("get", &ir::struct_type::get, ret::reference) .def_property_readonly("num_types", &ir::struct_type::get_num_types); - py::class_(m, "module") + py::class_(m, "module", py::dynamic_attr()) .def(py::init()) .def("has_function", &ir::module::has_function) .def("get_function", &ir::module::get_function, ret::reference) + .def("get_functions", &ir::module::get_function_list, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) .def("print", [](ir::module *self) { self->print(std::cout); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 46ddfe760..d00de5be5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -11,7 +11,7 @@ from numpy.random import RandomState import triton import triton._C.libtriton.triton as _triton import triton.language as tl -from triton.code_gen import JITFunction, TensorWrapper, reinterpret +from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] @@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or (dtype_x in uint_dtypes and dtype_y in int_dtypes))): - with pytest.raises(triton.code_gen.CompilationError) as exc_info: + with pytest.raises(triton.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) else: @@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): else: numpy_expr = None if 'float' in dtype_x + dtype_y: - with pytest.raises(triton.code_gen.CompilationError) as exc_info: + with pytest.raises(triton.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) # The CompilationError must have been caused by a C++ exception with this text. 
assert re.match('invalid operands of type', str(exc_info.value.__cause__)) @@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'): def catch_compilation_error(kernel): try: kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) - except triton.code_gen.CompilationError as e: + except triton.CompilationError as e: np.testing.assert_(True) except BaseException: np.testing.assert_(False) @@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache): assert 'ld.global.cg' not in ptx -@pytest.mark.parametrize("N", [8, 10, 11, 1024]) +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) def test_vectorization(N): src = torch.empty(1024, device='cuda') dst = torch.empty(1024, device='cuda') @@ -1221,10 +1221,8 @@ def test_vectorization(N): tl.store(dst + offsets, x, mask=offsets < N) pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0]) ptx = pgm.asm["ptx"] - if N % 4 == 0: + if N % 16 == 0: assert "ld.global.v4.b32" in ptx - elif N % 2 == 0: - assert "ld.global.v2.b32" in ptx else: assert "ld.global.b32" in ptx # triton.testing.assert_almost_equal(dst, src[:N]) @@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non def cache_hook(*args, **kwargs): nonlocal spec_type - spec_type = kwargs["compile"]["arg_types"][0][1] + spec_type = kwargs["compile"]["signature"][0] JITFunction.cache_hook = cache_hook @triton.jit @@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda' x = torch.tensor([3.14159], device='cuda') if overflow: - with pytest.raises(RuntimeError, match='integer overflow'): + with pytest.raises(OverflowError): kernel[(1, )](value, x) else: kernel[(1, )](value, x) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 514fbab7b..e14ea6ae7 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] kernel = triton.ops._matmul.kernel - decorators = kernel.kernel_decorators - kernel.kernel_decorators = [] - triton.autotune(configs, [])(kernel) - kernel.kernel_decorators += decorators[1:] + kernel.configs = configs + # kernel.run = kernel.run.run.run + # get matrix shape M = BLOCK_M if M is None else M N = BLOCK_N if N is None else N diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index fd95dbd38..6fad3af3d 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -7,7 +7,7 @@ import torch import triton import triton.language as tl -from triton.code_gen import JITFunction +from triton.runtime.jit import JITFunction tmpdir = ".tmp" @@ -99,16 +99,16 @@ def test_specialize(mode): reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') function = {'enable': kernel, 'disable': kernel_nospec}[mode] - target = {'enable': 5, 'disable': 1}[mode] + target = {'enable': 3, 'disable': 1}[mode] for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target @pytest.mark.parametrize("value, value_type", [ - (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), - (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), - (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 
'uint64') + (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'), + (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64') ]) def test_value_specialization(value: int, value_type: str, device='cuda') -> None: @@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non def get_cache_str(*args, **kwargs): nonlocal cache_str - cache_str = kwargs['key'].split('-') - triton.code_gen.JITFunction.cache_hook = get_cache_str + cache_str = kwargs["repr"] + triton.JITFunction.cache_hook = get_cache_str reset_tmp_dir() x = torch.tensor([3.14159], device='cuda') kernel[(1, )](value, x) - triton.code_gen.JITFunction.cache_hook = None + triton.JITFunction.cache_hook = None - cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) + cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str) spec_type = None if cache_str_match is None else cache_str_match.group(1) assert spec_type == value_type diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 37ba46efc..c620543ee 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,9 +6,10 @@ __version__ = '2.0.0' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ - JITFunction, Config, Autotuner, reinterpret +from .utils import * +from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface +from .runtime.jit import jit +from .compiler import compile, CompilationError from . import language -from . import code_gen from . import testing from . import ops diff --git a/python/triton/code_gen.py b/python/triton/compiler.py similarity index 50% rename from python/triton/code_gen.py rename to python/triton/compiler.py index e2956aea9..98ccc6b1a 100644 --- a/python/triton/code_gen.py +++ b/python/triton/compiler.py @@ -1,39 +1,48 @@ from __future__ import annotations import ast -import builtins +import contextlib import functools import hashlib -import inspect +import io import os -import pickle import subprocess import sys +import sysconfig import tempfile -import textwrap -import threading -import time import warnings -from typing import Dict, Set, Tuple, Union +from typing import Any, Dict, Set, Tuple, Union +import setuptools import torch from filelock import FileLock import triton import triton._C.libtriton.triton as _triton -from .tools.disasm import extract - -try: - from torch._C import _cuda_getCurrentRawStream as get_cuda_stream -except ImportError: - get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream -def current_cuda_stream(device_idx=0): - # Torch's torch.cuda.current_stream() is slow. We provide this - # function to give the user an opportunity to monkey-patch their - # own faster current stream lookup. 
- return get_cuda_stream(device_idx) +def str_to_ty(name): + if name[0] == "*": + ty = str_to_ty(name[1:]) + return triton.language.pointer_type(ty) + tys = { + "i1": triton.language.int1, + "fp8": triton.language.float8, + "fp16": triton.language.float16, + "bf16": triton.language.bfloat16, + "fp32": triton.language.float32, + "fp64": triton.language.float64, + "i8": triton.language.int8, + "i16": triton.language.int16, + "i32": triton.language.int32, + "i64": triton.language.int64, + "u8": triton.language.uint8, + "u16": triton.language.uint16, + "u32": triton.language.uint32, + "u64": triton.language.uint64, + "B": triton.language.int1, + } + return tys[name] def mangle_ty(ty): @@ -63,7 +72,7 @@ def mangle_ty(ty): def mangle_fn(name, arg_tys, constants): # doesn't mangle ret type, which must be a function of arg tys mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) - key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) + key = lambda x: x.__name__ if isinstance(x, triton.runtime.JITFunction) else repr(x) mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) mangled_constants = mangled_constants.replace('.', '_d_') mangled_constants = mangled_constants.replace("'", '_sq_') @@ -218,7 +227,8 @@ class ValueConstructor: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, prototypes=None, module=None, is_kernel=False): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, spec_to_1=None, prototypes=None, module=None, is_kernel=False): + self.spec_to_1 = set() if spec_to_1 is None else spec_to_1 self.prototypes = dict() if prototypes is None else prototypes self.builder = _triton.ir.builder(context) self.module = _triton.ir.module('', self.builder) if module is None else module @@ -226,6 +236,7 @@ class CodeGenerator(ast.NodeVisitor): self.attributes = attributes self.constants = constants self.last_node = None + self.function_name = function_name self.is_kernel = is_kernel self.value_constructor = ValueConstructor(self.module, self.builder, gscope) @@ -260,7 +271,7 @@ class CodeGenerator(ast.NodeVisitor): return ret def visit_FunctionDef(self, node): - arg_names, kwarg_names = self.visit(node.args) + arg_names, arg_annotations, kwarg_names = self.visit(node.args) # initialize defaults for i, default_value in enumerate(node.args.defaults): arg_node = node.args.args[-i - 1] @@ -273,28 +284,27 @@ class CodeGenerator(ast.NodeVisitor): init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) # initialize function - fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants) - self.prototypes[fn_name] = self.prototype - fn = self.module.get_or_insert_function(fn_name, self.prototype.to_ir(self.builder)) + self.prototypes[self.function_name] = self.prototype + fn = self.module.get_or_insert_function(self.function_name, self.prototype.to_ir(self.builder)) fn.set_is_kernel(self.is_kernel) arg_values = [] idx = 0 - for i, arg_name in enumerate(arg_names): + for i, (arg_name, annotation) in enumerate(zip(arg_names, arg_annotations)): if i in self.constants: cst = self.constants[i] if not isinstance(cst, triton.language.constexpr): cst = triton.language.constexpr(self.constants[i]) arg_values.append(cst) - else: - if i in self.attributes: - is_ptr = fn.args[idx].type.is_ptr() - attr = 'aligned' if is_ptr else 'multiple_of' - attr = getattr(_triton.ir.attribute_kind, attr) - attr = 
_triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(idx + 1, attr) - fn.args[idx].name = arg_name - arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) - idx += 1 + continue + if i in self.attributes: + is_ptr = fn.args[idx].type.is_ptr() + attr = 'aligned' if is_ptr else 'multiple_of' + attr = getattr(_triton.ir.attribute_kind, attr) + attr = _triton.ir.attribute(attr, self.attributes[i][1]) + fn.add_attr(idx + 1, attr) + fn.args[idx].name = arg_name + arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) + idx += 1 insert_pt = self.builder.get_insert_block() entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) @@ -309,20 +319,23 @@ class CodeGenerator(ast.NodeVisitor): self.builder.ret_void() else: # a bit hacky: we only know the return type at the last moment so we update type info here - self.module.reset_ret_ty(fn_name, self.last_ret.type.to_ir(self.builder)) + self.module.reset_ret_ty(self.function_name, self.last_ret.type.to_ir(self.builder)) self.prototype.ret_type = self.last_ret.type self.builder.set_insert_block(insert_pt) def visit_arguments(self, node): arg_names = [] + arg_annotations = [] for arg in node.args: - arg_names += [self.visit(arg)] + curr = self.visit(arg) + arg_names += [curr[0]] + arg_annotations += [curr[1]] kwarg_names = self.visit(node.kwarg) - return arg_names, kwarg_names + return arg_names, arg_annotations, kwarg_names def visit_arg(self, node): ast.NodeVisitor.generic_visit(self, node) - return node.arg + return node.arg, node.annotation def visit_AnnAssign(self, node): # extract attributes @@ -661,7 +674,7 @@ class CodeGenerator(ast.NodeVisitor): kws.update(self.visit(keyword)) args = [self.visit(arg) for arg in node.args] - if isinstance(fn, JITFunction): + if isinstance(fn, triton.runtime.JITFunction): from inspect import getcallargs args = getcallargs(fn.fn, *args, **kws) args = [args[name] for name in fn.arg_names] @@ -681,7 +694,7 @@ class CodeGenerator(ast.NodeVisitor): ret_type = triton.language.void prototype = triton.language.function_type(ret_type, arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, prototypes=self.prototypes, module=self.module) + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, function_name=fn_name, prototypes=self.prototypes, module=self.module) generator.visit(fn.parse()) symbol = self.module.get_function(fn_name) ret = self.builder.call(symbol, arg_vals) @@ -758,52 +771,6 @@ class CodeGenerator(ast.NodeVisitor): raise NotImplementedError("Unsupported node: {}".format(typename)) -class Binary: - def __init__(self, backend, name, asm, shared_mem, num_warps): - self.backend = backend - self.name = name - self.asm = asm - self.shared_mem = shared_mem - self.num_warps = num_warps - - -class LoadedBinary: - def __init__(self, device: int, bin: Binary): - module, kernel, n_regs, n_spills = _triton.code_gen.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, - device) - self.bin = bin - self.asm = bin.asm - self.sass = '' - self.module = module - self.kernel = kernel - self.n_regs = n_regs - self.n_spills = n_spills - self.device = device - self.shared_mem = bin.shared_mem - - def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): - _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, - grid_0, grid_1, grid_2, - self.bin.num_warps * 32, 1, 1, - args, 
self.bin.shared_mem) - - def get_sass(self, fun=None): - if self.sass: - return self.sass - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass - - class CompilationError(Exception): def __init__(self, src, node): self.message = f'at {node.lineno}:{node.col_offset}:\n' @@ -833,694 +800,464 @@ class OutOfResources(Exception): return (type(self), (self.required, self.limit, self.name)) -class Kernel: - - @staticmethod - def _type_name(obj): - type_names = { - triton.language.float8: 'f8', - torch.bfloat16: 'bf16', - torch.float16: 'f16', - torch.float32: 'f32', - torch.float64: 'f64', - torch.bool: 'i1', - torch.uint8: 'u8', - torch.int8: 'i8', - torch.int16: 'i16', - torch.int32: 'i32', - torch.int64: 'i64', - triton.language.uint8: 'u8', - triton.language.uint16: 'u16', - triton.language.uint32: 'u32', - triton.language.uint64: 'u64', - } - if hasattr(obj, 'data_ptr'): - return type_names[obj.dtype] - if isinstance(obj, triton.language.constexpr): - obj = obj.value - if isinstance(obj, int): - if -2**31 <= obj < 2**31: - return 'i32' - elif 2**31 <= obj < 2**32: - return 'u32' - elif -2**63 <= obj < 2**63: - return 'i64' - elif 2**63 <= obj < 2**64: - return 'u64' - else: - raise ValueError(f'integer overflow representing {obj}') - if isinstance(obj, float): - return 'f' - if isinstance(obj, bool): - return 'B' - if isinstance(obj, str): - return 'str' - raise NotImplementedError(f'could not compute type name for {obj}') - - @staticmethod - def _to_python_ir(obj): - # convert torch.Tensor to Triton IR pointers - if hasattr(obj, 'data_ptr'): - name = Kernel._type_name(obj) - return 'ptr', name - # default path returns triton.ir.type directly - name = Kernel._type_name(obj) - return 'scalar', name - - @staticmethod - def _to_triton_ir(obj): - which, name = obj - type_map = { - 'I': triton.language.int32, - 'L': triton.language.int64, - 'f': triton.language.float32, - 'B': triton.language.int1, - 'f8': triton.language.float8, - 'f16': triton.language.float16, - 'bf16': triton.language.bfloat16, - 'f32': triton.language.float32, - 'f64': triton.language.float64, - 'i1': triton.language.int1, - 'i8': triton.language.int8, - 'i16': triton.language.int16, - 'i32': triton.language.int32, - 'i64': triton.language.int64, - 'u8': triton.language.uint8, - 'u16': triton.language.uint16, - 'u32': triton.language.uint32, - 'u64': triton.language.uint64, - } - # convert torch.Tensor to Triton IR pointers - if which == 'ptr': - elt_ty = type_map[name] - return triton.language.pointer_type(elt_ty, 1) - # default path returns triton.ir.type directly - return type_map[name] - - @staticmethod - def pow2_divisor(N): - if N % 16 == 0: - return 16 - if N % 8 == 0: - return 8 - if N % 4 == 0: - return 4 - if N % 2 == 0: - return 2 - return 1 - - def __init__(self, fn): - self.fn = fn - self.cache_key = {} - - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs): - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - - # attributes - attributes = dict() - for i, arg in enumerate(wargs): - if i in self.fn.do_not_specialize: - continue - if isinstance(arg, int): - attributes[i] = Kernel.pow2_divisor(arg) - elif i in tensor_idxs: - addr = arg.data_ptr() - range_size = _triton.runtime.get_pointer_range_size(addr) - attributes[i] = min(Kernel.pow2_divisor(addr), - Kernel.pow2_divisor(range_size)) - # 
transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, - extern_libs=extern_libs, is_manual_warmup=False) - - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs): - assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2." - # handle arguments passed by name - kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} - wargs = list(wargs) - for i, pos in enumerate(sorted(kwargs)): - wargs.insert(pos + i, kwargs[pos]) - if len(wargs) != len(self.fn.arg_names): - raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") - # handle annotations - for pos, _type in self.fn.annotations.items(): - assert _type == triton.language.constexpr, "only constexpr annotations are supported for now" - wargs[pos] = _type(wargs[pos]) - # check that tensors are on GPU. - # for arg in wargs: - # if hasattr(arg, 'data_ptr'): - # assert arg.is_cuda, "All tensors must be on GPU!" - # set device (i.e., make sure torch has the context initialized) - device = torch.cuda.current_device() - # torch creates new thread for backward pass that may have uninitlialized context - # no way to know if this function should or shouldn't initialize the cuda context - # so we're being conservative here - torch.cuda.set_device(device) - if device not in self.cache_key: - cc = torch.cuda.get_device_capability(device) - cc = str(cc[0]) + '-' + str(cc[1]) - self.cache_key[device] = self.fn.cache_key + cc - cache_key = self.cache_key[device] - stream = current_cuda_stream(device) - return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, - device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache, - grid) +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix -class Launcher: - def __init__(self, kernel, grid): - self.kernel = kernel - self.grid = grid +def make_triton_ir(fn, signature, specialization, constants): + context = _triton.ir.context() + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) + tys = list(signature.values()) + new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1} + new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16} + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in 
signature.items() if k not in constants] - def __call__(self, *wargs, **kwargs): - return self.kernel(*wargs, **kwargs, grid=self.grid) + prototype = triton.language.function_type(triton.language.void, arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True) + try: + generator.visit(fn.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(fn.src, node) from e + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret, generator -class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): - ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - ''' - if not configs: - self.configs = [Config(dict(), num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.cache = dict() - self.kernel = kernel - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." 
- ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) - return triton.testing.do_bench(kernel_call) - - def __call__(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple([args[i] for i in self.key_idx]) - if key not in self.cache: - # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs} - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) +def make_ptx(mod: Any, device: int) -> Tuple[str, int]: + ''' + Translate TritonGPU module to PTX code. + :param mod: a TritonGPU dialect module + :return: + - PTX code + - shared memory alloaction size + ''' + return _triton.translate_triton_gpu_to_ptx(mod, device) -_version_key_lock = threading.Lock() -_version_key = None +def make_cubin(ptx, device): + ''' + Compile TritonGPU module to cubin. + :param ptx: ptx code + :param device: CUDA device + :return: str + ''' + return _triton.compile_ptx_to_cubin(ptx, device) -def version_key(): - global _version_key - - if _version_key is not None: - return _version_key - - with _version_key_lock: - if _version_key is not None: - return _version_key - - import pkgutil - contents = [] - # frontend - with open(triton.code_gen.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # backend - with open(triton._C.libtriton.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # language - language_path = os.path.join(*triton.__path__, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # ptxas version - try: - ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() - except Exception: - ptxas_version = '' - _version_key = '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) - return _version_key +def ptx_get_kernel_name(ptx: str) -> str: + ''' + Get kernel name from PTX code. + This Kernel name is required when launching the kernel. + ''' + # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. 
+ assert ptx + for line in ptx.split('\n'): + line = line.strip() + if line.startswith('// .globl'): + return line.split()[-1] -class DependenciesFinder(ast.NodeVisitor): +def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]: + valid_outputs = ("ttir", "ttgir", "ptx", "cubin") + assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output) - def __init__(self, globals, src) -> None: - super().__init__() - self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() - self.globals = globals + # triton-ir + module, _ = make_triton_ir(fn, signature, specialization, constants) + if output == "ttir": + return module - def visit_Name(self, node): - return self.globals.get(node.id, None) + assert output == "cubin" + assert torch.version.hip is None + backend = _triton.runtime.backend.CUDA + if extern_libs is None: + extern_libs = dict() + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs) + return asm, shared_mem, name - def visit_Attribute(self, node): - lhs = self.visit(node.value) - while isinstance(lhs, ast.Attribute): - lhs = self.visit(lhs.value) - if lhs is None or lhs is triton: - return None - return getattr(lhs, node.attr) - def visit_Call(self, node): - func = self.visit(node.func) - if func is None: - return - if inspect.isbuiltin(func): - return - if func.__module__ and func.__module__.startswith('triton.'): - return - assert isinstance(func, triton.JITFunction) - if func.hash is None: - tree = ast.parse(func.src) - finder = DependenciesFinder(func.__globals__, func.src) - finder.visit(tree) - func.hash = finder.ret - self.ret = (self.ret + func.hash).encode("utf-8") - self.ret = hashlib.md5(self.ret).hexdigest() +def ty_to_cpp(ty): + if ty[0] == '*': + return "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp32": "float", + }[ty] + + +def generate_name_initializer(signature): + src = "int i = 0;\n" + tys = signature.split(',') + for i, ty in enumerate(tys): + src + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +@functools.lru_cache() +def libcuda_dir(): + loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1] + return os.path.dirname(loc) + + +def _build(name, src, path): + # add framework + extra_compile_args = [] + library_dirs = [libcuda_dir()] + include_dirs = [path, "/usr/local/cuda/include/"] + libraries = ['cuda'] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c++', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + path) + args.append('--build-lib=' + path) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + # with quiet(): + setuptools.setup(**args) + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(path, 
'{name}{suffix}'.format(name=name, suffix=suffix)) + return so + + +def generate_torch_glue(kernel_name, constants, signature, num_warps, binaries, tmpdir): + headers = dict() + + # write all cubins to header files + assert len(binaries) == 1, "AoT compilation not yet supported" + + for bin, shmem_size, name in binaries: + assert len(name) < 1024 + initializer = f""" +const char* {name}_ptx = R"({bin["ptx"]})"; +unsigned char {name}_bin[] = {{ {','.join(map(hex, bin["cubin"]))} }}; +unsigned int {name}_shmem = {shmem_size};""" + headers[name] = os.path.join(tmpdir, f"{name}.h") + with open(headers[name], "w") as f: + f.write(initializer) + + func_init = '\n '.join(f"init_function(\"{name}\", {name}_bin, {name}_shmem, device);" for _, _, name in binaries) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp32': 'float', + 'fp64': 'double', + }[ty] + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + format = "iiiK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + + # generate glue code + src = "" + for bin, shmem_size, name in binaries: + src += f"#include \"{name}.h\"\n" + src += f""" +#include \"cuda.h\" +#include + +inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyErr_SetString(PyExc_RuntimeError, err); + }} +}} +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +static CUmodule module = 0; +static CUfunction function = 0; + +static void init_function(const char* name, const unsigned char* src, size_t n_shared_bytes, int64_t device){{ + CUmodule mod; + CUfunction fun; + CUDA_CHECK(cuModuleLoadData(&mod, src)); + CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); + if (n_shared_bytes > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + int n_spills, n_reg; + CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device)); + CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + CUDA_CHECK(cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); + }} + module = mod; + function = fun; +}} + +static void init_module(CUdevice device) {{ + {func_init} +}} + + +void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{ + CUcontext ctx; + CUdevice device; + CUDA_CHECK(cuStreamGetCtx(stream, &ctx)); + CUDA_CHECK(cuCtxGetDevice(&device)); + + // TODO: machine may have heterogeneous devices + if(function == 0){{ + init_module(device); + }} + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() 
if i not in constants)} }}; + if(gridX*gridY*gridZ > 0){{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*{num_warps}, 1, 1, {name}_shmem, stream, params, 0)); + }} +}} + +CUdeviceptr getPointer(PyObject *obj, int idx) {{ + if (PyLong_Check(obj)) {{ + return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj); + }} + if (obj == Py_None) {{ + return (CUdeviceptr)0; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + }} + return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret); + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return (CUdeviceptr)0; +}} + + +static PyObject* {kernel_name}(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t stream; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &stream, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + return NULL; + }} + + _{kernel_name}(gridX, gridY, gridZ, (CUstream)stream, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"{kernel_name}", {kernel_name}, METH_VARARGS, "Call {kernel_name} kernel"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"{kernel_name}\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit_{kernel_name}() {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + PyObject *ptx = PyDict_New(); +""" + + for _, _, name in binaries: + src += f""" + PyObject *py_{name}_ptx = PyUnicode_FromString({name}_ptx); + PyDict_SetItemString(ptx, "{name}", py_{name}_ptx); + Py_DECREF(py_{name}_ptx); +""" + + src += """ + PyModule_AddObject(m, "ptx", ptx); + return m; +} +""" + + return src def default_cache_dir(): return os.path.join(os.environ["HOME"], ".triton", "cache") -class JITFunction: +class CacheManager: - cache_hook = None + def __init__(self, key): + self.key = key + self.bin_path = None + self.lock_path = None + # if caching is enabled, get the lock and bin path + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + if self.cache_dir: + os.makedirs(self.cache_dir, exist_ok=True) + if self.cache_dir: + self.bin_path = os.path.join(self.cache_dir, self.key + ".so") + self.lock_path = self.bin_path + ".lock" - def __init__(self, fn, version=None, inline=True, do_not_specialize=None): - # information of wrapped function - self.fn = fn - self.module = fn.__module__ - signature = inspect.signature(fn) - self.arg_names = [v.name for v in signature.parameters.values()] - self.arg_defaults = [v.default for v in signature.parameters.values()] + def has_file(self): + return self.bin_path and os.path.exists(self.bin_path) - self.version = version - self.inline = inline - self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[self.src.find("def"):] - self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize - 
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] - # cache for callable driver objects (e.g. CUkernel) - self.bin_cache = dict() - self.hash = None - # JITFunction can be instantiated as kernel - # when called with a grid using __getitem__ - self.kernel_decorators = [] - self.kernel = None - # annotations - self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} - self.__annotations__ = fn.__annotations__ - # constexprs - self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] - # forward docs - self.__doc__ = fn.__doc__ - self.__name__ = fn.__name__ - self.__globals__ = fn.__globals__ - self.__module__ = fn.__module__ + def put(self, binary): + if self.bin_path: + assert self.lock_path is not None + with FileLock(self.lock_path): + with open(self.bin_path + ".tmp", "wb") as f: + f.write(binary) + os.rename(self.bin_path + ".tmp", self.bin_path) - @property - @functools.lru_cache() - def cache_key(self): - # TODO : hash should be attribute of `self` - if self.hash is None: - dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) - dependencies_finder.visit(self.parse()) - self.hash = dependencies_finder.ret + version_key() - return self.hash - # we do not parse `src` in the constructor because - # the user might want to monkey-patch self.src dynamically. - # Some unit tests do this, for example. - def parse(self): - tree = ast.parse(self.src) - assert isinstance(tree, ast.Module) - assert len(tree.body) == 1 - assert isinstance(tree.body[0], ast.FunctionDef) - return tree +def make_cache_key(fn, signature, configs, constants, num_warps, num_stages): + # Get unique key for the compiled code + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) + configs_key = [get_conf_key(conf) for conf in configs] + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key - def __call__(self, *args, **kwargs): - raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.") - # - when `.src` attribute is set, cache path needs - # to be reinitialized - # - when kernel decorators change, cached kernel - # needs to be cleared - def __setattr__(self, name, value): - if name == 'kernel_decorators': - self.kernel = None - super(JITFunction, self).__setattr__(name, value) - if name == 'src': - self.hash = None - JITFunction.cache_key.fget.cache_clear() +def make_shared_object(fn, constants, signature, num_warps, binaries, tmpdir): + src = generate_torch_glue(fn.__name__, constants, signature, num_warps, binaries, tmpdir) + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + with quiet(): + bin_path = _build(fn.__name__, src_path, tmpdir) + with open(bin_path, "rb") as f: + return f.read() - def _init_kernel(self): - if self.kernel is None: - self.kernel = Kernel(self) - for decorator in reversed(self.kernel_decorators): - self.kernel = decorator(self.kernel) - return self.kernel - def warmup(self, compile): - return self._warmup(**compile, is_manual_warmup=True) +def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None): + # we get the kernel, i.e. 
the first function generated in the module + if configs is None: + assert False, "automatic specialization is not supported yet" + ref, _ = make_triton_ir(fn, signature, _triton.code_gen.instance_descriptor(), constants) + fns = ref.get_functions() + configs = _triton.infer_specialization_configs(fns[0]) + assert len(configs) == 1 + # cache manager + cache_key = make_cache_key(fn, signature, configs, constants, num_warps, num_stages) + cache_manager = CacheManager(cache_key) + # retrieve cached shared object if it exists + if cache_manager.has_file(): + return CompiledKernel(fn.__name__, cache_manager.bin_path) + # compile all the configs + binaries = [] + for config in configs: + binaries.append(_compile(fn, signature, device, constants, config, num_warps, num_stages, extern_libs, "cubin")) + # generate and compile glue code into shared object + with tempfile.TemporaryDirectory() as tmpdir: + all_constants = set(constants.keys()) + all_constants.update(configs[0].equal_to_1) + so = make_shared_object(fn, all_constants, signature, num_warps, binaries, tmpdir) - def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup): - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() + # write shared object to cache + cache_manager.put(so) + return CompiledKernel(fn.__name__, cache_manager.bin_path) - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) - if cache_dir: - os.makedirs(cache_dir, exist_ok=True) - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None +class CompiledKernel: - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - - compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs) - if JITFunction.cache_hook is not None: - name = self.__name__ - info = key.split('-')[-3:] - num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] - # make signature human-readable - arg_reprs = [] - for arg_name, arg_sig in zip(self.arg_names, sig): - arg_reprs.append(f'{arg_name}: {arg_sig}') - # assemble the repr - arg_reprs = ", ".join(arg_reprs) - repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" - noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None) - if noop: - return True - - if binary is None: - binary = self._compile(**compile) - - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - - self.bin_cache[key] = LoadedBinary(device, binary) - return False - - def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs): - # create IR module - context = _triton.ir.context() - # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type(ret_type, arg_types) - # generate Triton-IR - # export symbols visible from self into 
code-generator object - gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) - try: - generator.visit(self.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e - # Compile to machine code - if torch.version.hip is None: - backend = _triton.runtime.backend.CUDA - else: - backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs) - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - if shared_mem > max_shared_memory: - raise OutOfResources(shared_mem, max_shared_memory, "shared memory") - return Binary(backend, name, asm, shared_mem, num_warps) + def __init__(self, fn_name, data_path): + import importlib.util + spec = importlib.util.spec_from_file_location(fn_name, data_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.c_wrapper = getattr(mod, fn_name) + ptx = getattr(mod, "ptx") + if len(ptx) == 1: + self.asm = {"ptx": list(ptx.values())[0]} def __getitem__(self, grid): - return Launcher(self._init_kernel(), grid) - - def __repr__(self): - return f"JITFunction({self.module}:{self.fn.__name__})" - - -class Config: - """ - An object that represents a possible kernel configuration for the auto-tuner to try. - - :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. - :type meta: dict[Str, Any] - :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if - `num_warps=8`, then each kernel instance will be automatically parallelized to - cooperatively execute using `8 * 32 = 256` threads. - :type num_warps: int - :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. - Mostly useful for matrix multiplication workloads on SM80+ GPUs. - :type num_stages: int - :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this - function are args. - """ - - def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): - self.kwargs = kwargs - self.num_warps = num_warps - self.num_stages = num_stages - self.pre_hook = pre_hook - - def __str__(self): - res = [] - for k, v in self.kwargs.items(): - res.append(f'{k}: {v}') - res.append(f'num_warps: {self.num_warps}') - res.append(f'num_stages: {self.num_stages}') - return ', '.join(res) - - -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - - .. highlight:: python - .. code-block:: python - - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. 
- - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - def decorator(fn): - def wrapper(kernel): - return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) - - fn.kernel_decorators.append(wrapper) - return fn - - return decorator - - -def heuristics(values): - """ - Decorator for specifying how the values of certain meta-parameters may be computed. - This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. - - .. highlight:: python - .. code-block:: python - - @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - - - .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. - each such function takes a list of positional arguments as input. - .type values: dict[str, Callable[[list[Any]], Any]] - """ - def decorator(fn): - def wrapper(kernel): - def fun(*args, **meta): - for v, heur in values.items(): - assert v not in meta - meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) - return kernel(*args, **meta) - return fun - - fn.kernel_decorators.append(wrapper) - return fn - - return decorator - - -def jit(*args, **kwargs): - """ - Decorator for JIT-compiling a function using the Triton compiler. - - :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. - - :note: This function will be compiled and run on the GPU. 
It will only have access to: - - * python primitives, - * objects within the triton.language package, - * arguments to this function, - * other jit'd functions - - :param fn: the function to be jit-compiled - :type fn: Callable - """ - if args: - assert len(args) == 1 - assert callable(args[0]) - return JITFunction(args[0], **kwargs) - else: - def decorator(fn): - return JITFunction(fn, **kwargs) - return decorator - -###### - -# class ForwardDeclaration: - -# def __init__(self, name, ret_ty, arg_tys) -> None: -# self.name = name -# self.ret_ty = ret_ty -# self.arg_tys = arg_tys - -# def forward_declare(name, ret_ty, arg_tys): -# return ForwardDeclaration(name, ret_ty, arg_tys) - -###### - - -def cdiv(x, y): - return (x + y - 1) // y - - -def next_power_of_2(n): - """Return the smallest power of 2 greater than or equal to n""" - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n - -###### - - -class TensorWrapper: - def __init__(self, base, dtype): - self.dtype = dtype - self.base = base - self.is_cuda = base.is_cuda - self.device = base.device - - def data_ptr(self): - return self.base.data_ptr() - - def __str__(self) -> str: - return f'TensorWrapper[{self.dtype}]({self.base})' - - -def reinterpret(tensor, dtype): - if isinstance(tensor, TensorWrapper): - if dtype == tensor.base.dtype: - # Reinterpreting to the original interpretation; return the base. - return tensor.base - else: - # Reinterpreting a wrapped tensor to a different type. - return TensorWrapper(tensor.base, dtype) - elif isinstance(tensor, torch.Tensor): - # A new wrapper is needed around an unwrapped tensor. - return TensorWrapper(tensor, dtype) - else: - raise TypeError(f'Cannot reinterpret a {type(tensor)}.') + def runner(*args, stream=None): + if stream is None: + stream = torch.cuda.current_stream().cuda_stream + self.c_wrapper(grid[0], grid[1], grid[2], stream, *args) + return runner diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e52a488b2..63a9ab7f2 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -420,7 +420,7 @@ class tensor: self.numel = 1 for s in self.shape: self.numel *= s - is_pow2 = (self.numel and (not(self.numel & (self.numel - 1)))) + is_pow2 = (self.numel and (not (self.numel & (self.numel - 1)))) if not is_pow2: raise ValueError("Triton tensors must have a power-of-two number of elements") self.numel = constexpr(self.numel) diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index bb915be13..33223b72d 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -18,8 +18,8 @@ def num_warps(n): @triton.jit def _blocksparse_softmax_fwd( - Out, A, stride_xz, LUT, - R, extent, stride_zr, stride_hr, # relative attention + Out, A, LUT, R, stride_xz, + extent, stride_zr, stride_hr, # relative attention scale, is_causal, ROW_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function): # enqueue kernel out = torch.empty_like(a) _blocksparse_softmax_fwd[grid]( - out, a, a.stride(0), lut, - rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn + out, a, lut, rel_logits, a.stride(0), + rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn scale, is_causal, BLOCK_SIZE=block, diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index f1ac78849..0ffcc1677 100644 --- a/python/triton/ops/matmul.py +++ 
b/python/triton/ops/matmul.py @@ -26,9 +26,6 @@ def get_configs_io_bound(): return configs -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, -}) @triton.autotune( configs=[ # basic configs for compute-bound matmuls @@ -59,6 +56,9 @@ def get_configs_io_bound(): 'top_k': 10 }, ) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) @triton.jit def _kernel(A, B, C, M, N, K, stride_am, stride_ak, diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py new file mode 100644 index 000000000..d9946c27c --- /dev/null +++ b/python/triton/runtime/__init__.py @@ -0,0 +1,2 @@ +from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401 +from .jit import JITFunction, KernelInterface, version_key # noqa: F401 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..2175501b6 --- /dev/null +++ b/python/triton/runtime/autotuner.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import builtins +import time +from typing import Dict + +from ..testing import do_bench +from .jit import KernelInterface + + +class Autotuner(KernelInterface): + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + ''' + if not configs: + self.configs = [Config(dict(), num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = dict() + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+ ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + return do_bench(kernel_call) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple([args[i] for i in self.key_idx]) + if key not in self.cache: + # prune configs + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_stages: int + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_stages = num_stages + self.pre_hook = pre_hook + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f'{k}: {v}') + res.append(f'num_warps: {self.num_warps}') + res.append(f'num_stages: {self.num_stages}') + return ', '.join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. 
+ To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predict running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune' (optional): a function used to do early pruning (e.g., of num_stages). It takes configs: List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args['x_size'])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + + + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a dict of the kernel's arguments, keyed by name, as input. + :type values: dict[str, Callable[[dict[str, Any]], Any]] + """ + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py new file mode 100644 index 000000000..025f268ac --- /dev/null +++ b/python/triton/runtime/jit.py @@ -0,0 +1,415 @@ +from __future__ import annotations, division + +import ast +import functools +import hashlib +import inspect +import os +import subprocess +import textwrap +from collections import namedtuple + +import torch + +import triton + +try: + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream +except ImportError: + get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes.
+ """ + + def __init__(self, globals, src) -> None: + super().__init__() + self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.globals = globals + + def visit_Name(self, node): + return self.globals.get(node.id, None) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or lhs is triton: + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + func = self.visit(node.func) + if func is None: + return + if inspect.isbuiltin(func): + return + if func.__module__ and func.__module__.startswith('triton.'): + return + assert isinstance(func, JITFunction) + if func.hash is None: + tree = ast.parse(func.src) + finder = DependenciesFinder(func.__globals__, func.src) + finder.visit(tree) + func.hash = finder.ret + self.ret = (self.ret + func.hash).encode("utf-8") + self.ret = hashlib.md5(self.ret).hexdigest() + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +@functools.lru_cache() +def version_key(): + import pkgutil + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + with open(triton.compiler.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # backend + with open(triton._C.libtriton.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # language + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version + try: + ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() + except Exception: + ptxas_version = '' + return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + + +class KernelInterface: + + def __getitem__(self, grid): + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. 
+ """ + def launcher(*args, **kwargs): + return self.run(*args, grid=grid, **kwargs) + return launcher + + +class JITFunction(KernelInterface): + + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -2**31 <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**31 <= arg and arg <= 2**32 - 1: + return "u32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return 'fp32' + elif arg is None: + return None + else: + raise TypeError(f'Unsupported type {type(arg)} for {arg}') + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return (arg.data_ptr() % JITFunction.divisibility == 0) + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} + equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1) + + @staticmethod + def _type_of(key): + if isinstance(key, (torch.dtype, triton.language.dtype)): + ty = { + torch.bool: 'i1', + torch.float16: 'fp16', + torch.bfloat16: 'bf16', + torch.float32: 'fp32', + torch.float64: 'fp64', + torch.uint8: 'u8', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + + triton.language.uint8: 'u8', + triton.language.uint16: 'u16', + triton.language.uint32: 'u32', + triton.language.uint64: 'u64', + triton.language.float8: 'fp8', + }[key] + return f'*{ty}' + if key is None: + return '*i8' + assert isinstance(key, str) + return key + + def _make_signature(self, sig_key): + signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) + return signature + + def _make_constants(self, constexpr_key): + constants = {i: k for i, k in zip(self.constexprs, constexpr_key)} + return constants + + def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + if JITFunction.cache_hook is None: + return False + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + key = str(key) + + class LegacyCompiler: + def __init__(self, module, name): + self.module = module + self.name = name + pass + + kwargs = dict(signature=signature, device=device, constants=constants, + num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, + configs=configs) + + return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) + + def _make_launcher(self): + regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs] + constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs] + args = ', 
'.join(regular_args) + # cache key for regular argument type + sig_keys = ', '.join([f'_key_of({arg})' for arg in regular_args]) + # cache key for constexpr argument values + constexpr_keys = ', '.join(constexpr_args) + # cache key for argument specialization + specializations = [] + for i, arg in enumerate(regular_args): + if i in self.do_not_specialize: + continue + specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") ' + f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) ' + f'else (False,)'] + spec_keys = ', '.join(specializations) + grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) + + src = f""" +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None): + sig_key = {sig_keys}, + constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} + spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} + key = (version_key, sig_key, constexpr_key, spec_key) + if not extern_libs is None: + key = (key, tuple(extern_libs.items())) + assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" + if callable(grid): + grid = grid({{{grid_args}}}) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + device = torch.cuda.current_device() + torch.cuda.set_device(device) + if stream is None: + stream = get_cuda_stream(device) + try: + bin = cache[key] + bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + return bin + # kernel not cached -- compile + except KeyError: + # build dict of constant values + args = [{args}] + configs = self._get_config(*args), + constants = self._make_constants(constexpr_key) + constants.update({{i: None for i, arg in enumerate(args) if arg is None}}) + constants.update({{i: 1 for i in configs[0].equal_to_1}}) + # build kernel signature -- doesn't include specialized arguments + all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} + # build stub signature -- includes arguments that are specialized + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + device = 0 + if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) + bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + self.cache[key] = bin + return bin + return None +""" + scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream, + "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of, + "cache": self.cache, "triton": triton, "torch": torch} + exec(src, scope) + return scope[self.fn.__name__] + + def __init__(self, fn, version=None, do_not_specialize=None): + self.fn = fn + self.module = fn.__module__ + self.version = version + # function signature information + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()]) + # specialization hints + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = set([self.arg_names.index(arg) if 
isinstance(arg, str) else arg for arg in self.do_not_specialize]) + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def"):] + # cache of just-in-time compiled kernels + self.cache = dict() + self.hash = None + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel_decorators = [] + self.kernel = None + # annotations + self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} + self.__annotations__ = fn.__annotations__ + # index of constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] + # launcher + self.run = self._make_launcher() + # re-use docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + @functools.lru_cache() + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + version_key() + return self.hash + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + # - when kernel decorators change, cached kernel + # needs to be cleared + if name == 'kernel_decorators': + self.kernel = None + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == 'src': + self.hash = None + JITFunction.cache_key.fget.cache_clear() + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +def jit(*args, **kwargs): + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. + + :note: This function will be compiled and run on the GPU. 
It will only have access to: + + * python primitives, + * objects within the triton.language package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + if args: + assert len(args) == 1 + assert callable(args[0]) + return JITFunction(args[0], **kwargs) + else: + def decorator(fn): + return JITFunction(fn, **kwargs) + return decorator + + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/python/triton/testing.py b/python/triton/testing.py index 594edcbf2..2c9ece2fe 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -7,7 +7,7 @@ from contextlib import contextmanager import torch import triton._C.libtriton.triton as _triton -from .code_gen import OutOfResources +from .compiler import OutOfResources try: import triton._C.libtriton.cutlass as _cutlass diff --git a/python/triton/utils.py b/python/triton/utils.py new file mode 100644 index 000000000..f446dd06a --- /dev/null +++ b/python/triton/utils.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import torch + + +def cdiv(x, y): + return (x + y - 1) // y + + +def next_power_of_2(n): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n + + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. 
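# (Editorial sketch, not part of the patch.) Typical use of reinterpret(), e.g. in unit
# tests: view an existing integer buffer under a Triton-only element type such as
# triton.language.float8, which has no native torch counterpart. The tensor shape, the
# values, and the top-level triton.reinterpret alias are illustrative assumptions.
#
#   x_int8 = torch.randint(0, 127, (128,), dtype=torch.int8, device='cuda')
#   x_fp8 = triton.reinterpret(x_int8, triton.language.float8)  # wraps the storage, no copy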
+ return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') From e647402fd3e818b5b6c45e14088ac5d357867390 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 18 Sep 2022 12:57:32 -0700 Subject: [PATCH 177/215] Fix warning in generated C code (#667) --- python/triton/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 98ccc6b1a..ac65fd49d 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1139,7 +1139,7 @@ static struct PyModuleDef ModuleDef = {{ ModuleMethods }}; -PyMODINIT_FUNC PyInit_{kernel_name}() {{ +PyMODINIT_FUNC PyInit_{kernel_name}(void) {{ PyObject *m = PyModule_Create(&ModuleDef); if(m == NULL) {{ return NULL; From 00f4ef6958d991479100a4672ca5fd7cbb653da2 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 18 Sep 2022 13:26:42 -0700 Subject: [PATCH 178/215] [CI] wheel/docs workflows now only run on V100 machine --- .github/workflows/documentation.yml | 2 +- .github/workflows/wheels.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 7dfb0a489..28cb20e37 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -8,7 +8,7 @@ jobs: Build-Documentation: - runs-on: self-hosted + runs-on: [self-hosted, V100] steps: diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index db682f33f..d627888c5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -8,7 +8,7 @@ jobs: Build-Wheels: - runs-on: self-hosted + runs-on: [self-hosted, V100] steps: From 49f6bc3f2b128ed899c875c0b98bab5c982b8267 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 18 Sep 2022 14:26:29 -0700 Subject: [PATCH 179/215] [FRONTEND] Fix filename too long error in new runtime (#669) --- python/triton/compiler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index ac65fd49d..b49d21b99 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -969,6 +969,13 @@ def _build(name, src, path): return so +def binary_name_to_header_name(name): + if len(name) > 128: + # avoid filename too long errors (filename limit is 255) + name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest() + return f"{name}.h" + + def generate_torch_glue(kernel_name, constants, signature, num_warps, binaries, tmpdir): headers = dict() @@ -981,7 +988,7 @@ def generate_torch_glue(kernel_name, constants, signature, num_warps, binaries, const char* {name}_ptx = R"({bin["ptx"]})"; unsigned char {name}_bin[] = {{ {','.join(map(hex, bin["cubin"]))} }}; unsigned int {name}_shmem = {shmem_size};""" - headers[name] = os.path.join(tmpdir, f"{name}.h") + headers[name] = os.path.join(tmpdir, binary_name_to_header_name(name)) with open(headers[name], "w") as f: f.write(initializer) @@ -1018,7 +1025,7 @@ unsigned int {name}_shmem = {shmem_size};""" # generate glue code src = "" for bin, shmem_size, name in binaries: - src += f"#include \"{name}.h\"\n" + src += f"#include \"{headers[name]}\"\n" src += f""" #include \"cuda.h\" #include From 2baf333d44cc3802b5f6f3756b3324fcf62d787b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 18 Sep 2022 17:13:12 -0700 Subject: [PATCH 180/215] [DOCS] Fixed typos (#670) --- docs/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py 
b/docs/conf.py index 4d62c5650..8a6fabce7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -45,7 +45,7 @@ def setup(app): def wrapped(obj, **kwargs): import triton - if isinstance(obj, triton.code_gen.JITFunction): + if isinstance(obj, triton.runtime.JITFunction): obj = obj.fn return old(obj) @@ -56,7 +56,7 @@ def setup(app): def documenter(app, obj, parent): import triton - if isinstance(obj, triton.code_gen.JITFunction): + if isinstance(obj, triton.runtime.JITFunction): obj = obj.fn return old_documenter(app, obj, parent) From 82956e5d6bf66fb49ad67efc9b9ec77513c57df4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 18 Sep 2022 17:34:05 -0700 Subject: [PATCH 181/215] [PACKAGING] Added missing package --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index b0e48f251..3940cf050 100644 --- a/python/setup.py +++ b/python/setup.py @@ -135,7 +135,7 @@ setup( author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", long_description="", - packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"], + packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"], install_requires=[ "cmake", "filelock", From 93b1adc53bb4eb680c4638f7c0b51f07ddb61940 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 18 Sep 2022 23:09:34 -0700 Subject: [PATCH 182/215] [FRONTEND] Add .warmup() for triton.jit() (#671) --- python/test/unit/runtime/test_cache.py | 22 ++++++++++++++ python/triton/runtime/autotuner.py | 40 +++++++++++++++++++------- python/triton/runtime/jit.py | 29 +++++++++++++------ python/triton/utils.py | 18 ++++++++++++ 4 files changed, 90 insertions(+), 19 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 6fad3af3d..3540208d3 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None: except BaseException: error = True assert error is True + + +def test_jit_warmup_cache() -> None: + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, + tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + 32, + ] + assert len(kernel_add.cache) == 0 + kernel_add[(1,)].warmup(torch.float32, torch.float32, torch.float32, 32) + assert len(kernel_add.cache) == 1 + kernel_add[(1,)].warmup(*args) + assert len(kernel_add.cache) == 1 + kernel_add[(1,)](*args) + assert len(kernel_add.cache) == 1 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 2175501b6..8ec16c477 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -68,16 +68,7 @@ class Autotuner(KernelInterface): key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, 
num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + pruned_configs = self.prune_configs(kwargs) bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} @@ -94,6 +85,35 @@ class Autotuner(KernelInterface): config.pre_hook(self.nargs) return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + class Config: """ diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 025f268ac..4911c327f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -12,6 +12,7 @@ from collections import namedtuple import torch import triton +from triton.utils import MockTensor try: from torch._C import _cuda_getCurrentRawStream as get_cuda_stream @@ -101,9 +102,16 @@ class KernelInterface: Hence JITFunction.__getitem__ returns a callable proxy that memorizes the grid. 
""" - def launcher(*args, **kwargs): - return self.run(*args, grid=grid, **kwargs) - return launcher + class Launcher: + @staticmethod + def __call__(*args, **kwargs): + return self.run(*args, grid=grid, **kwargs) + + @staticmethod + def warmup(*args, **kwargs): + return self.warmup(*args, grid=grid, **kwargs) + + return Launcher() class JITFunction(KernelInterface): @@ -231,7 +239,7 @@ class JITFunction(KernelInterface): grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) src = f""" -def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None): +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False): sig_key = {sig_keys}, constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} @@ -247,11 +255,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage grid_2 = grid[2] if grid_size > 2 else 1 device = torch.cuda.current_device() torch.cuda.set_device(device) - if stream is None: + if stream is None and not warmup: stream = get_cuda_stream(device) try: bin = cache[key] - bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) return bin # kernel not cached -- compile except KeyError: @@ -271,7 +280,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage device = 0 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) - bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) self.cache[key] = bin return bin return None @@ -317,7 +327,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.__module__ = fn.__module__ @property - @functools.lru_cache() def cache_key(self): # TODO : hash should be attribute of `self` if self.hash is None: @@ -326,6 +335,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.hash = dependencies_finder.ret + version_key() return self.hash + def warmup(self, *args, **kwargs): + return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True) + # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. @@ -349,7 +361,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage # to be reinitialized if name == 'src': self.hash = None - JITFunction.cache_key.fget.cache_clear() def __repr__(self): return f"JITFunction({self.module}:{self.fn.__name__})" diff --git a/python/triton/utils.py b/python/triton/utils.py index f446dd06a..2ac84d06e 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -19,6 +19,24 @@ def next_power_of_2(n): return n +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) 
+ """ + @staticmethod + def wrap_dtype(arg): + if isinstance(arg, torch.dtype): + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + def data_ptr(self): + return 0 # optimistically assumes multiple of 16 + + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype From 48f30550f135f23888310809b6743f68563543c7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 19 Sep 2022 21:01:36 -0700 Subject: [PATCH 183/215] [FRONTEND] Now using raw compiler syscalls when possible (#678) --- python/triton/compiler.py | 142 ++++++++++++++++++++------------------ 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index b49d21b99..67c32cef4 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -6,11 +6,13 @@ import functools import hashlib import io import os +import shutil import subprocess import sys import sysconfig import tempfile import warnings +from sysconfig import get_paths from typing import Any, Dict, Set, Tuple, Union import setuptools @@ -917,58 +919,6 @@ def generate_name_initializer(signature): src -@contextlib.contextmanager -def quiet(): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = io.StringIO(), io.StringIO() - try: - yield - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr - - -@functools.lru_cache() -def libcuda_dir(): - loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1] - return os.path.dirname(loc) - - -def _build(name, src, path): - # add framework - extra_compile_args = [] - library_dirs = [libcuda_dir()] - include_dirs = [path, "/usr/local/cuda/include/"] - libraries = ['cuda'] - # extra arguments - extra_link_args = [] - # create extension module - ext = setuptools.Extension( - name=name, - language='c++', - sources=[src], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], - extra_link_args=extra_link_args, - library_dirs=library_dirs, - libraries=libraries, - ) - # build extension module - args = ['build_ext'] - args.append('--build-temp=' + path) - args.append('--build-lib=' + path) - args.append('-q') - args = dict( - name=name, - ext_modules=[ext], - script_args=args, - ) - # with quiet(): - setuptools.setup(**args) - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(path, '{name}{suffix}'.format(name=name, suffix=suffix)) - return so - - def binary_name_to_header_name(name): if len(name) > 128: # avoid filename too long errors (filename limit is 255) @@ -1030,7 +980,7 @@ unsigned int {name}_shmem = {shmem_size};""" #include \"cuda.h\" #include -inline void gpuAssert(CUresult code, const char *file, int line) +static inline void gpuAssert(CUresult code, const char *file, int line) {{ if (code != CUDA_SUCCESS) {{ @@ -1048,7 +998,7 @@ inline void gpuAssert(CUresult code, const char *file, int line) static CUmodule module = 0; static CUfunction function = 0; -static void init_function(const char* name, const unsigned char* src, size_t n_shared_bytes, int64_t device){{ +static inline void init_function(const char* name, const unsigned char* src, size_t n_shared_bytes, int64_t device){{ CUmodule mod; CUfunction fun; CUDA_CHECK(cuModuleLoadData(&mod, src)); @@ -1070,7 +1020,7 @@ static void init_function(const char* name, const unsigned char* src, size_t n_s function = fun; }} -static void init_module(CUdevice device) {{ +static inline void init_module(CUdevice device) {{ {func_init} }} @@ -1209,16 +1159,72 @@ def 
make_cache_key(fn, signature, configs, constants, num_warps, num_stages): key = hashlib.md5(key.encode("utf-8")).hexdigest() return key +# utilties for generating and compiling C wrappers -def make_shared_object(fn, constants, signature, num_warps, binaries, tmpdir): - src = generate_torch_glue(fn.__name__, constants, signature, num_warps, binaries, tmpdir) - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) + +@functools.lru_cache() +def libcuda_dir(): + loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1] + return os.path.dirname(loc) + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir): + cuda_lib_dir = libcuda_dir() + cu_include_dir = "/usr/local/cuda/include" + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + py_include_dir = get_paths()["include"] + ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so]) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + library_dirs = [cuda_lib_dir] + include_dirs = [srcdir, cu_include_dir] + libraries = ['cuda'] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) with quiet(): - bin_path = _build(fn.__name__, src_path, tmpdir) - with open(bin_path, "rb") as f: - return f.read() + setuptools.setup(**args) + return so def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None): @@ -1243,10 +1249,14 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i with tempfile.TemporaryDirectory() as tmpdir: all_constants = set(constants.keys()) all_constants.update(configs[0].equal_to_1) - so = make_shared_object(fn, all_constants, signature, num_warps, binaries, tmpdir) + src = generate_torch_glue(fn.__name__, constants, signature, num_warps, binaries, tmpdir) + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(fn.__name__, src_path, tmpdir) + with open(so, "rb") as f: + cache_manager.put(f.read()) - # write shared object to cache - cache_manager.put(so) return CompiledKernel(fn.__name__, cache_manager.bin_path) From 7dc2a70edb386e14074f91ee8eeb1f78f270761c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 20 Sep 2022 16:05:14 -0700 Subject: [PATCH 184/215] Revert "Add .warmup() for triton.jit()" (#682) Reverts openai/triton#671 It seems like for some reason this caused out-of-memory errors on some 
of our internal workloads. I'm reverting this so that HEAD can be used in production at OpenAI, and I will work on digging into this issue asynchronously. --- python/test/unit/runtime/test_cache.py | 22 -------------- python/triton/runtime/autotuner.py | 40 +++++++------------------- python/triton/runtime/jit.py | 29 ++++++------------- python/triton/utils.py | 18 ------------ 4 files changed, 19 insertions(+), 90 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 3540208d3..6fad3af3d 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -150,25 +150,3 @@ def test_constexpr_not_callable() -> None: except BaseException: error = True assert error is True - - -def test_jit_warmup_cache() -> None: - @triton.jit - def kernel_add(a, b, o, N: tl.constexpr): - idx = tl.arange(0, N) - tl.store(o + idx, - tl.load(a + idx) + tl.load(b + idx)) - - args = [ - torch.randn(32, dtype=torch.float32, device="cuda"), - torch.randn(32, dtype=torch.float32, device="cuda"), - torch.randn(32, dtype=torch.float32, device="cuda"), - 32, - ] - assert len(kernel_add.cache) == 0 - kernel_add[(1,)].warmup(torch.float32, torch.float32, torch.float32, 32) - assert len(kernel_add.cache) == 1 - kernel_add[(1,)].warmup(*args) - assert len(kernel_add.cache) == 1 - kernel_add[(1,)](*args) - assert len(kernel_add.cache) == 1 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 8ec16c477..2175501b6 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -68,7 +68,16 @@ class Autotuner(KernelInterface): key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: # prune configs - pruned_configs = self.prune_configs(kwargs) + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} @@ -85,35 +94,6 @@ class Autotuner(KernelInterface): config.pre_hook(self.nargs) return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = { - config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, - num_warps=config.num_warps) - for config in pruned_configs - } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - for config in self.prune_configs(kwargs): - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - 
self.nargs = None - class Config: """ diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4911c327f..025f268ac 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -12,7 +12,6 @@ from collections import namedtuple import torch import triton -from triton.utils import MockTensor try: from torch._C import _cuda_getCurrentRawStream as get_cuda_stream @@ -102,16 +101,9 @@ class KernelInterface: Hence JITFunction.__getitem__ returns a callable proxy that memorizes the grid. """ - class Launcher: - @staticmethod - def __call__(*args, **kwargs): - return self.run(*args, grid=grid, **kwargs) - - @staticmethod - def warmup(*args, **kwargs): - return self.warmup(*args, grid=grid, **kwargs) - - return Launcher() + def launcher(*args, **kwargs): + return self.run(*args, grid=grid, **kwargs) + return launcher class JITFunction(KernelInterface): @@ -239,7 +231,7 @@ class JITFunction(KernelInterface): grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) src = f""" -def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False): +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None): sig_key = {sig_keys}, constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} @@ -255,12 +247,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage grid_2 = grid[2] if grid_size > 2 else 1 device = torch.cuda.current_device() torch.cuda.set_device(device) - if stream is None and not warmup: + if stream is None: stream = get_cuda_stream(device) try: bin = cache[key] - if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) return bin # kernel not cached -- compile except KeyError: @@ -280,8 +271,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage device = 0 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) - if not warmup: - bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) self.cache[key] = bin return bin return None @@ -327,6 +317,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.__module__ = fn.__module__ @property + @functools.lru_cache() def cache_key(self): # TODO : hash should be attribute of `self` if self.hash is None: @@ -335,9 +326,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.hash = dependencies_finder.ret + version_key() return self.hash - def warmup(self, *args, **kwargs): - return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True) - # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. 
@@ -361,6 +349,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage # to be reinitialized if name == 'src': self.hash = None + JITFunction.cache_key.fget.cache_clear() def __repr__(self): return f"JITFunction({self.module}:{self.fn.__name__})" diff --git a/python/triton/utils.py b/python/triton/utils.py index 2ac84d06e..f446dd06a 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -19,24 +19,6 @@ def next_power_of_2(n): return n -class MockTensor: - """ - Can be used in place of real tensors when calling: - kernel.warmup(MockTensor(torch.float32), ...) - """ - @staticmethod - def wrap_dtype(arg): - if isinstance(arg, torch.dtype): - return MockTensor(arg) - return arg - - def __init__(self, dtype): - self.dtype = dtype - - def data_ptr(self): - return 0 # optimistically assumes multiple of 16 - - class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype From e318185eb4561908327a35861f9dc14f3918fdd3 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 20 Sep 2022 18:09:43 -0700 Subject: [PATCH 185/215] [DOCS] Improved README.md wording (#683) Initial wording dates from a time where nobody knew Triton, and comparing it to CUDA helped differentiate it from other existing DSLs. But nowadays this comparison doesn't make much sense; Triton is its own thing, and some people may even still be more productive in CUDA than Triton -- language preferences are subjective after all. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f4b6ef41c..ed0fc71b1 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ # Triton -This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs. +This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment for expressing tensor math workloads that offers high flexibility, developer productivity and end to end performance. The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton! From 6abe813d1c5690d79cbbac6ca9375627c19dd3ee Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 21 Sep 2022 10:20:48 -0700 Subject: [PATCH 186/215] Fix issue breaking cudagraphs (#685) @ngimel figured this one out. The errors we were seeing from cudagraphs capture were coming from `cuStreamGetCtx` which is not allowed while a stream is capturing. It appears the result of `cuStreamGetCtx()` isn't even used, so I believe it can just be removed. 
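For context, a minimal sketch of the kind of CUDA-graph capture this change unblocks. The kernel, sizes, and names below are hypothetical (they are not taken from this patch), and it assumes PyTorch's torch.cuda.CUDAGraph / torch.cuda.graph API; the kernel is warmed up once before capture so compilation happens outside the capture:

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _copy(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
        # hypothetical elementwise copy kernel, used only to have something to launch
        offsets = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


    n_elements = 1 << 20
    src = torch.randn(n_elements, device="cuda")
    dst = torch.empty_like(src)
    grid = (triton.cdiv(n_elements, 1024),)

    # compile and cache the kernel outside of capture
    _copy[grid](src, dst, n_elements, BLOCK=1024)

    # capturing the launch used to fail because the generated C wrapper called
    # cuStreamGetCtx, which CUDA forbids while a stream is capturing
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        _copy[grid](src, dst, n_elements, BLOCK=1024)

    graph.replay()  # re-issues the recorded launch without going through the Python launcher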
--- python/triton/compiler.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 67c32cef4..535f323da 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1026,13 +1026,10 @@ static inline void init_module(CUdevice device) {{ void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{ - CUcontext ctx; - CUdevice device; - CUDA_CHECK(cuStreamGetCtx(stream, &ctx)); - CUDA_CHECK(cuCtxGetDevice(&device)); - // TODO: machine may have heterogeneous devices if(function == 0){{ + CUdevice device; + CUDA_CHECK(cuCtxGetDevice(&device)); init_module(device); }} void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; From 677ddae618139014ebd4b0407a767c9b900e81b5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 21 Sep 2022 12:13:20 -0700 Subject: [PATCH 187/215] [FRONTEND] Add warmup for triton.jit() (#684) This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object Co-authored-by: Jason Ansel --- python/test/unit/runtime/test_cache.py | 22 ++++++++++++++ python/triton/runtime/autotuner.py | 40 +++++++++++++++++++------- python/triton/runtime/jit.py | 16 +++++++---- python/triton/utils.py | 18 ++++++++++++ 4 files changed, 80 insertions(+), 16 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 6fad3af3d..6d6c0e131 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -150,3 +150,25 @@ def test_constexpr_not_callable() -> None: except BaseException: error = True assert error is True + + +def test_jit_warmup_cache() -> None: + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, + tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + 32, + ] + assert len(kernel_add.cache) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + assert len(kernel_add.cache) == 1 + kernel_add.warmup(*args, grid=(1,)) + assert len(kernel_add.cache) == 1 + kernel_add.warmup(*args, grid=(1,)) + assert len(kernel_add.cache) == 1 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 2175501b6..8ec16c477 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -68,16 +68,7 @@ class Autotuner(KernelInterface): key = tuple([args[i] for i in self.key_idx]) if key not in self.cache: # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + pruned_configs = self.prune_configs(kwargs) bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} @@ -94,6 +85,35 @@ class Autotuner(KernelInterface): 
config.pre_hook(self.nargs) return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + class Config: """ diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 025f268ac..0187a7faa 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -12,6 +12,7 @@ from collections import namedtuple import torch import triton +from triton.utils import MockTensor try: from torch._C import _cuda_getCurrentRawStream as get_cuda_stream @@ -231,7 +232,7 @@ class JITFunction(KernelInterface): grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) src = f""" -def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None): +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False): sig_key = {sig_keys}, constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else tuple()} spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else tuple()} @@ -247,11 +248,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage grid_2 = grid[2] if grid_size > 2 else 1 device = torch.cuda.current_device() torch.cuda.set_device(device) - if stream is None: + if stream is None and not warmup: stream = get_cuda_stream(device) try: bin = cache[key] - bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args}) return bin # kernel not cached -- compile except KeyError: @@ -271,7 +273,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage device = 0 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) - bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args) self.cache[key] = bin return bin return None @@ -317,7 +320,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.__module__ = fn.__module__ @property - @functools.lru_cache() def cache_key(self): # TODO : hash should be attribute of `self` if self.hash is None: @@ -326,6 +328,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.hash = dependencies_finder.ret + version_key() return self.hash + def warmup(self, *args, **kwargs): + return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True) + # we do not parse 
`src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. @@ -349,7 +354,6 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage # to be reinitialized if name == 'src': self.hash = None - JITFunction.cache_key.fget.cache_clear() def __repr__(self): return f"JITFunction({self.module}:{self.fn.__name__})" diff --git a/python/triton/utils.py b/python/triton/utils.py index f446dd06a..2ac84d06e 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -19,6 +19,24 @@ def next_power_of_2(n): return n +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + @staticmethod + def wrap_dtype(arg): + if isinstance(arg, torch.dtype): + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + def data_ptr(self): + return 0 # optimistically assumes multiple of 16 + + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype From df67068bb02894d49ffe6696166fc99548339510 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Wed, 21 Sep 2022 20:18:02 -0700 Subject: [PATCH 188/215] [pybind11] Update pybind11 to 2.10.0 (#691) This PR updates the version of pybind11 to 2.10.0 (the latest stable). --- python/src/pybind11/attr.h | 391 ++- python/src/pybind11/buffer_info.h | 167 +- python/src/pybind11/cast.h | 2221 ++++++-------- python/src/pybind11/chrono.h | 193 +- python/src/pybind11/complex.h | 39 +- python/src/pybind11/detail/class.h | 401 ++- python/src/pybind11/detail/common.h | 1070 ++++--- python/src/pybind11/detail/descr.h | 116 +- python/src/pybind11/detail/init.h | 303 +- python/src/pybind11/detail/internals.h | 485 +++- python/src/pybind11/detail/type_caster_base.h | 1010 +++++++ python/src/pybind11/detail/typeid.h | 36 +- python/src/pybind11/eigen.h | 559 ++-- python/src/pybind11/embed.h | 185 +- python/src/pybind11/eval.h | 129 +- python/src/pybind11/functional.h | 90 +- python/src/pybind11/gil.h | 202 ++ python/src/pybind11/iostream.h | 154 +- python/src/pybind11/numpy.h | 1514 ++++++---- python/src/pybind11/operators.h | 253 +- python/src/pybind11/options.h | 45 +- python/src/pybind11/pybind11.h | 2584 +++++++++++------ python/src/pybind11/pytypes.h | 1815 +++++++++--- python/src/pybind11/stl.h | 295 +- python/src/pybind11/stl/filesystem.h | 116 + python/src/pybind11/stl_bind.h | 715 +++-- 26 files changed, 9869 insertions(+), 5219 deletions(-) create mode 100644 python/src/pybind11/detail/type_caster_base.h create mode 100644 python/src/pybind11/gil.h create mode 100644 python/src/pybind11/stl/filesystem.h diff --git a/python/src/pybind11/attr.h b/python/src/pybind11/attr.h index 6962d6fc5..db7cd8eff 100644 --- a/python/src/pybind11/attr.h +++ b/python/src/pybind11/attr.h @@ -10,65 +10,113 @@ #pragma once +#include "detail/common.h" #include "cast.h" -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +#include + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) /// \addtogroup annotations /// @{ /// Annotation for methods -struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; +struct is_method { + handle class_; + explicit is_method(const handle &c) : class_(c) {} +}; /// Annotation for operators -struct is_operator { }; +struct is_operator {}; + +/// Annotation for classes that cannot be subclassed +struct is_final {}; /// Annotation for parent scope -struct scope { handle value; scope(const handle &s) : value(s) { } }; +struct scope { + handle 
value; + explicit scope(const handle &s) : value(s) {} +}; /// Annotation for documentation -struct doc { const char *value; doc(const char *value) : value(value) { } }; +struct doc { + const char *value; + explicit doc(const char *value) : value(value) {} +}; /// Annotation for function names -struct name { const char *value; name(const char *value) : value(value) { } }; +struct name { + const char *value; + explicit name(const char *value) : value(value) {} +}; /// Annotation indicating that a function is an overload associated with a given "sibling" -struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; +struct sibling { + handle value; + explicit sibling(const handle &value) : value(value.ptr()) {} +}; /// Annotation indicating that a class derives from another given type -template struct base { - PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") - base() { } +template +struct base { + + PYBIND11_DEPRECATED( + "base() was deprecated in favor of specifying 'T' as a template argument to class_") + base() = default; }; /// Keep patient alive while nurse lives -template struct keep_alive { }; +template +struct keep_alive {}; /// Annotation indicating that a class is involved in a multiple inheritance relationship -struct multiple_inheritance { }; +struct multiple_inheritance {}; /// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class -struct dynamic_attr { }; +struct dynamic_attr {}; /// Annotation which enables the buffer protocol for a type -struct buffer_protocol { }; +struct buffer_protocol {}; /// Annotation which requests that a special metaclass is created for a type struct metaclass { handle value; PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") - metaclass() {} + metaclass() = default; /// Override pybind11's default metaclass - explicit metaclass(handle value) : value(value) { } + explicit metaclass(handle value) : value(value) {} +}; + +/// Specifies a custom callback with signature `void (PyHeapTypeObject*)` that +/// may be used to customize the Python type. +/// +/// The callback is invoked immediately before `PyType_Ready`. +/// +/// Note: This is an advanced interface, and uses of it may require changes to +/// work with later versions of pybind11. You may wish to consult the +/// implementation of `make_new_python_type` in `detail/classes.h` to understand +/// the context in which the callback will be run. +struct custom_type_setup { + using callback = std::function; + + explicit custom_type_setup(callback value) : value(std::move(value)) {} + + callback value; }; /// Annotation that marks a class as local to the module: -struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; +struct module_local { + const bool value; + constexpr explicit module_local(bool v = true) : value(v) {} +}; /// Annotation to mark enums as an arithmetic type -struct arithmetic { }; +struct arithmetic {}; + +/// Mark a function for addition at the beginning of the existing overload chain instead of the end +struct prepend {}; /** \rst A call policy which places one or more guard variables (``Ts...``) around the function call. 
@@ -88,9 +136,13 @@ struct arithmetic { }; return foo(args...); // forwarded arguments }); \endrst */ -template struct call_guard; +template +struct call_guard; -template <> struct call_guard<> { using type = detail::void_type; }; +template <> +struct call_guard<> { + using type = detail::void_type; +}; template struct call_guard { @@ -110,13 +162,14 @@ struct call_guard { /// @} annotations -NAMESPACE_BEGIN(detail) +PYBIND11_NAMESPACE_BEGIN(detail) /* Forward declarations */ enum op_id : int; enum op_type : int; struct undefined_t; -template struct op_; -inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); +template +struct op_; +void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); /// Internal data structure which holds metadata about a keyword argument struct argument_record { @@ -127,14 +180,16 @@ struct argument_record { bool none : 1; ///< True if None is allowed when loading argument_record(const char *name, const char *descr, handle value, bool convert, bool none) - : name(name), descr(descr), value(value), convert(convert), none(none) { } + : name(name), descr(descr), value(value), convert(convert), none(none) {} }; -/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) +/// Internal data structure which holds metadata about a bound function (signature, overloads, +/// etc.) struct function_record { function_record() : is_constructor(false), is_new_style_constructor(false), is_stateless(false), - is_operator(false), has_args(false), has_kwargs(false), is_method(false) { } + is_operator(false), is_method(false), has_args(false), has_kwargs(false), + prepend(false) {} /// Function name char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ @@ -149,13 +204,13 @@ struct function_record { std::vector args; /// Pointer to lambda function which converts arguments and performs the actual call - handle (*impl) (function_call &) = nullptr; + handle (*impl)(function_call &) = nullptr; /// Storage for the wrapped function pointer and captured data, if any - void *data[3] = { }; + void *data[3] = {}; /// Pointer to custom destructor for 'data' (if needed) - void (*free_data) (function_record *ptr) = nullptr; + void (*free_data)(function_record *ptr) = nullptr; /// Return value policy associated with this function return_value_policy policy = return_value_policy::automatic; @@ -172,18 +227,28 @@ struct function_record { /// True if this is an operator (__add__), etc. bool is_operator : 1; + /// True if this is a method + bool is_method : 1; + /// True if the function has a '*args' argument bool has_args : 1; /// True if the function has a '**kwargs' argument bool has_kwargs : 1; - /// True if this is a method - bool is_method : 1; + /// True if this function is to be inserted at the beginning of the overload resolution chain + bool prepend : 1; /// Number of arguments (including py::args and/or py::kwargs, if present) std::uint16_t nargs; + /// Number of leading positional arguments, which are terminated by a py::args or py::kwargs + /// argument or by a py::kw_only annotation. 
+ std::uint16_t nargs_pos = 0; + + /// Number of leading arguments (counted in `nargs`) that are positional-only + std::uint16_t nargs_pos_only = 0; + /// Python method object PyMethodDef *def = nullptr; @@ -201,7 +266,7 @@ struct function_record { struct type_record { PYBIND11_NOINLINE type_record() : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), - default_holder(true), module_local(false) { } + default_holder(true), module_local(false), is_final(false) {} /// Handle to the parent scope handle scope; @@ -239,6 +304,9 @@ struct type_record { /// Custom metaclass (optional) handle metaclass; + /// Custom type setup. + custom_type_setup::callback custom_type_setup_callback; + /// Multiple inheritance marker bool multiple_inheritance : 1; @@ -254,42 +322,48 @@ struct type_record { /// Is the class definition local to the module shared object? bool module_local : 1; - PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { - auto base_info = detail::get_type_info(base, false); + /// Is the class inheritable from python classes? + bool is_final : 1; + + PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *) ) { + auto *base_info = detail::get_type_info(base, false); if (!base_info) { std::string tname(base.name()); detail::clean_type_id(tname); - pybind11_fail("generic_type: type \"" + std::string(name) + - "\" referenced unknown base type \"" + tname + "\""); + pybind11_fail("generic_type: type \"" + std::string(name) + + "\" referenced unknown base type \"" + tname + "\""); } if (default_holder != base_info->default_holder) { std::string tname(base.name()); detail::clean_type_id(tname); - pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + - (default_holder ? "does not have" : "has") + - " a non-default holder type while its base \"" + tname + "\" " + - (base_info->default_holder ? "does not" : "does")); + pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + + (default_holder ? "does not have" : "has") + + " a non-default holder type while its base \"" + tname + "\" " + + (base_info->default_holder ? "does not" : "does")); } bases.append((PyObject *) base_info->type); - if (base_info->type->tp_dictoffset != 0) - dynamic_attr = true; +#if PY_VERSION_HEX < 0x030B0000 + dynamic_attr |= base_info->type->tp_dictoffset != 0; +#else + dynamic_attr |= (base_info->type->tp_flags & Py_TPFLAGS_MANAGED_DICT) != 0; +#endif - if (caster) + if (caster) { base_info->implicit_casts.emplace_back(type, caster); + } } }; -inline function_call::function_call(const function_record &f, handle p) : - func(f), parent(p) { +inline function_call::function_call(const function_record &f, handle p) : func(f), parent(p) { args.reserve(f.nargs); args_convert.reserve(f.nargs); } /// Tag for a new-style `__init__` defined in `detail/init.h` -struct is_new_style_constructor { }; +struct is_new_style_constructor {}; /** * Partial template specializations to process custom attributes provided to @@ -297,105 +371,177 @@ struct is_new_style_constructor { }; * fields in the type_record and function_record data structures or executed at * runtime to deal with custom call policies (e.g. keep_alive). 
*/ -template struct process_attribute; +template +struct process_attribute; -template struct process_attribute_default { +template +struct process_attribute_default { /// Default implementation: do nothing - static void init(const T &, function_record *) { } - static void init(const T &, type_record *) { } - static void precall(function_call &) { } - static void postcall(function_call &, handle) { } + static void init(const T &, function_record *) {} + static void init(const T &, type_record *) {} + static void precall(function_call &) {} + static void postcall(function_call &, handle) {} }; /// Process an attribute specifying the function's name -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } }; /// Process an attribute specifying the function's docstring -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } }; /// Process an attribute specifying the function's docstring (provided as a C-style string) -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const char *d, function_record *r) { r->doc = const_cast(d); } static void init(const char *d, type_record *r) { r->doc = const_cast(d); } }; -template <> struct process_attribute : process_attribute { }; +template <> +struct process_attribute : process_attribute {}; /// Process an attribute indicating the function's return value policy -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const return_value_policy &p, function_record *r) { r->policy = p; } }; -/// Process an attribute which indicates that this is an overloaded function associated with a given sibling -template <> struct process_attribute : process_attribute_default { +/// Process an attribute which indicates that this is an overloaded function associated with a +/// given sibling +template <> +struct process_attribute : process_attribute_default { static void init(const sibling &s, function_record *r) { r->sibling = s.value; } }; /// Process an attribute which indicates that this function is a method -template <> struct process_attribute : process_attribute_default { - static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } +template <> +struct process_attribute : process_attribute_default { + static void init(const is_method &s, function_record *r) { + r->is_method = true; + r->scope = s.class_; + } }; /// Process an attribute which indicates the parent scope of a method -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const scope &s, function_record *r) { r->scope = s.value; } }; /// Process an attribute which indicates that this function is an operator -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const is_operator &, function_record *r) { r->is_operator = true; } }; -template <> struct process_attribute : process_attribute_default { - static void init(const 
is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } +template <> +struct process_attribute + : process_attribute_default { + static void init(const is_new_style_constructor &, function_record *r) { + r->is_new_style_constructor = true; + } }; +inline void check_kw_only_arg(const arg &a, function_record *r) { + if (r->args.size() > r->nargs_pos && (!a.name || a.name[0] == '\0')) { + pybind11_fail("arg(): cannot specify an unnamed argument after a kw_only() annotation or " + "args() argument"); + } +} + +inline void append_self_arg_if_needed(function_record *r) { + if (r->is_method && r->args.empty()) { + r->args.emplace_back("self", nullptr, handle(), /*convert=*/true, /*none=*/false); + } +} + /// Process a keyword argument attribute (*without* a default value) -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const arg &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); + append_self_arg_if_needed(r); r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); + + check_kw_only_arg(a, r); } }; /// Process a keyword argument attribute (*with* a default value) -template <> struct process_attribute : process_attribute_default { +template <> +struct process_attribute : process_attribute_default { static void init(const arg_v &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); + if (r->is_method && r->args.empty()) { + r->args.emplace_back( + "self", /*descr=*/nullptr, /*parent=*/handle(), /*convert=*/true, /*none=*/false); + } if (!a.value) { -#if !defined(NDEBUG) +#if defined(PYBIND11_DETAILED_ERROR_MESSAGES) std::string descr("'"); - if (a.name) descr += std::string(a.name) + ": "; + if (a.name) { + descr += std::string(a.name) + ": "; + } descr += a.type + "'"; if (r->is_method) { - if (r->name) - descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; - else + if (r->name) { + descr += " in method '" + (std::string) str(r->scope) + "." + + (std::string) r->name + "'"; + } else { descr += " in method of '" + (std::string) str(r->scope) + "'"; + } } else if (r->name) { descr += " in function '" + (std::string) r->name + "'"; } - pybind11_fail("arg(): could not convert default argument " - + descr + " into a Python object (type not registered yet?)"); + pybind11_fail("arg(): could not convert default argument " + descr + + " into a Python object (type not registered yet?)"); #else pybind11_fail("arg(): could not convert default argument " "into a Python object (type not registered yet?). " - "Compile in debug mode for more information."); + "#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for " + "more information."); #endif } r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); + + check_kw_only_arg(a, r); } }; -/// Process a parent class attribute. 
Single inheritance only (class_ itself already guarantees that) +/// Process a keyword-only-arguments-follow pseudo argument +template <> +struct process_attribute : process_attribute_default { + static void init(const kw_only &, function_record *r) { + append_self_arg_if_needed(r); + if (r->has_args && r->nargs_pos != static_cast(r->args.size())) { + pybind11_fail("Mismatched args() and kw_only(): they must occur at the same relative " + "argument location (or omit kw_only() entirely)"); + } + r->nargs_pos = static_cast(r->args.size()); + } +}; + +/// Process a positional-only-argument maker +template <> +struct process_attribute : process_attribute_default { + static void init(const pos_only &, function_record *r) { + append_self_arg_if_needed(r); + r->nargs_pos_only = static_cast(r->args.size()); + if (r->nargs_pos_only > r->nargs_pos) { + pybind11_fail("pos_only(): cannot follow a py::args() argument"); + } + // It also can't follow a kw_only, but a static_assert in pybind11.h checks that + } +}; + +/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees +/// that) template -struct process_attribute::value>> : process_attribute_default { +struct process_attribute::value>> + : process_attribute_default { static void init(const handle &h, type_record *r) { r->bases.append(h); } }; @@ -408,7 +554,9 @@ struct process_attribute> : process_attribute_default> { /// Process a multiple inheritance attribute template <> struct process_attribute : process_attribute_default { - static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } + static void init(const multiple_inheritance &, type_record *r) { + r->multiple_inheritance = true; + } }; template <> @@ -416,6 +564,18 @@ struct process_attribute : process_attribute_default static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } }; +template <> +struct process_attribute { + static void init(const custom_type_setup &value, type_record *r) { + r->custom_type_setup_callback = value.value; + } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const is_final &, type_record *r) { r->is_final = true; } +}; + template <> struct process_attribute : process_attribute_default { static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } @@ -431,46 +591,70 @@ struct process_attribute : process_attribute_default static void init(const module_local &l, type_record *r) { r->module_local = l.value; } }; +/// Process a 'prepend' attribute, putting this at the beginning of the overload chain +template <> +struct process_attribute : process_attribute_default { + static void init(const prepend &, function_record *r) { r->prepend = true; } +}; + /// Process an 'arithmetic' attribute for enums (does nothing here) template <> struct process_attribute : process_attribute_default {}; template -struct process_attribute> : process_attribute_default> { }; +struct process_attribute> : process_attribute_default> {}; /** * Process a keep_alive call policy -- invokes keep_alive_impl during the * pre-call handler if both Nurse, Patient != 0 and use the post-call handler * otherwise */ -template struct process_attribute> : public process_attribute_default> { +template +struct process_attribute> + : public process_attribute_default> { template = 0> - static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } + static void precall(function_call &call) { + 
keep_alive_impl(Nurse, Patient, call, handle()); + } template = 0> - static void postcall(function_call &, handle) { } + static void postcall(function_call &, handle) {} template = 0> - static void precall(function_call &) { } + static void precall(function_call &) {} template = 0> - static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } + static void postcall(function_call &call, handle ret) { + keep_alive_impl(Nurse, Patient, call, ret); + } }; /// Recursively iterate over variadic template arguments -template struct process_attributes { - static void init(const Args&... args, function_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); +template +struct process_attributes { + static void init(const Args &...args, function_record *r) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r); + using expander = int[]; + (void) expander{ + 0, ((void) process_attribute::type>::init(args, r), 0)...}; } - static void init(const Args&... args, type_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); + static void init(const Args &...args, type_record *r) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r); + using expander = int[]; + (void) expander{0, + (process_attribute::type>::init(args, r), 0)...}; } static void precall(function_call &call) { - int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; - ignore_unused(unused); + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call); + using expander = int[]; + (void) expander{0, + (process_attribute::type>::precall(call), 0)...}; } static void postcall(function_call &call, handle fn_ret) { - int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... 
}; - ignore_unused(unused); + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call, fn_ret); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(fn_ret); + using expander = int[]; + (void) expander{ + 0, (process_attribute::type>::postcall(call, fn_ret), 0)...}; } }; @@ -484,10 +668,11 @@ using extract_guard_t = typename exactly_one_t, Extr /// Check the number of named arguments at compile time template ::value...), - size_t self = constexpr_sum(std::is_same::value...)> + size_t self = constexpr_sum(std::is_same::value...)> constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { - return named == 0 || (self + named + has_args + has_kwargs) == nargs; + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(nargs, has_args, has_kwargs); + return named == 0 || (self + named + size_t(has_args) + size_t(has_kwargs)) == nargs; } -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/python/src/pybind11/buffer_info.h b/python/src/pybind11/buffer_info.h index 9f072fa73..06120d556 100644 --- a/python/src/pybind11/buffer_info.h +++ b/python/src/pybind11/buffer_info.h @@ -11,56 +11,122 @@ #include "detail/common.h" -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +PYBIND11_NAMESPACE_BEGIN(detail) + +// Default, C-style strides +inline std::vector c_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + if (ndim > 0) { + for (size_t i = ndim - 1; i > 0; --i) { + strides[i - 1] = strides[i] * shape[i]; + } + } + return strides; +} + +// F-style strides; default when constructing an array_t with `ExtraFlags & f_style` +inline std::vector f_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + for (size_t i = 1; i < ndim; ++i) { + strides[i] = strides[i - 1] * shape[i - 1]; + } + return strides; +} + +PYBIND11_NAMESPACE_END(detail) /// Information record describing a Python buffer object struct buffer_info { void *ptr = nullptr; // Pointer to the underlying storage ssize_t itemsize = 0; // Size of individual items in bytes ssize_t size = 0; // Total number of entries - std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() + std::string format; // For homogeneous buffers, this should be set to + // format_descriptor::format() ssize_t ndim = 0; // Number of dimensions std::vector shape; // Shape of the tensor (1 entry per dimension) - std::vector strides; // Number of entries between adjacent entries (for each per dimension) + std::vector strides; // Number of bytes between adjacent entries + // (for each per dimension) + bool readonly = false; // flag to indicate if the underlying storage may be written to - buffer_info() { } + buffer_info() = default; - buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, - detail::any_container shape_in, detail::any_container strides_in) - : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), - shape(std::move(shape_in)), strides(std::move(strides_in)) { - if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) + buffer_info(void *ptr, + ssize_t itemsize, + const std::string &format, + ssize_t ndim, + detail::any_container shape_in, + detail::any_container strides_in, + bool readonly = false) + : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), + shape(std::move(shape_in)), 
strides(std::move(strides_in)), readonly(readonly) { + if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) { pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); - for (size_t i = 0; i < (size_t) ndim; ++i) + } + for (size_t i = 0; i < (size_t) ndim; ++i) { size *= shape[i]; + } } template - buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) - : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } + buffer_info(T *ptr, + detail::any_container shape_in, + detail::any_container strides_in, + bool readonly = false) + : buffer_info(private_ctr_tag(), + ptr, + sizeof(T), + format_descriptor::format(), + static_cast(shape_in->size()), + std::move(shape_in), + std::move(strides_in), + readonly) {} - buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) - : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } + buffer_info(void *ptr, + ssize_t itemsize, + const std::string &format, + ssize_t size, + bool readonly = false) + : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}, readonly) {} template - buffer_info(T *ptr, ssize_t size) - : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } + buffer_info(T *ptr, ssize_t size, bool readonly = false) + : buffer_info(ptr, sizeof(T), format_descriptor::format(), size, readonly) {} + + template + buffer_info(const T *ptr, ssize_t size, bool readonly = true) + : buffer_info( + const_cast(ptr), sizeof(T), format_descriptor::format(), size, readonly) {} explicit buffer_info(Py_buffer *view, bool ownview = true) - : buffer_info(view->buf, view->itemsize, view->format, view->ndim, - {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { - this->view = view; + : buffer_info( + view->buf, + view->itemsize, + view->format, + view->ndim, + {view->shape, view->shape + view->ndim}, + /* Though buffer::request() requests PyBUF_STRIDES, ctypes objects + * ignore this flag and return a view with NULL strides. + * When strides are NULL, build them manually. */ + view->strides + ? 
std::vector(view->strides, view->strides + view->ndim) + : detail::c_strides({view->shape, view->shape + view->ndim}, view->itemsize), + (view->readonly != 0)) { + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + this->m_view = view; + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) this->ownview = ownview; } buffer_info(const buffer_info &) = delete; - buffer_info& operator=(const buffer_info &) = delete; + buffer_info &operator=(const buffer_info &) = delete; - buffer_info(buffer_info &&other) { - (*this) = std::move(other); - } + buffer_info(buffer_info &&other) noexcept { (*this) = std::move(other); } - buffer_info& operator=(buffer_info &&rhs) { + buffer_info &operator=(buffer_info &&rhs) noexcept { ptr = rhs.ptr; itemsize = rhs.itemsize; size = rhs.size; @@ -68,41 +134,60 @@ struct buffer_info { ndim = rhs.ndim; shape = std::move(rhs.shape); strides = std::move(rhs.strides); - std::swap(view, rhs.view); + std::swap(m_view, rhs.m_view); std::swap(ownview, rhs.ownview); + readonly = rhs.readonly; return *this; } ~buffer_info() { - if (view && ownview) { PyBuffer_Release(view); delete view; } + if (m_view && ownview) { + PyBuffer_Release(m_view); + delete m_view; + } } + Py_buffer *view() const { return m_view; } + Py_buffer *&view() { return m_view; } + private: - struct private_ctr_tag { }; + struct private_ctr_tag {}; - buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, - detail::any_container &&shape_in, detail::any_container &&strides_in) - : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } + buffer_info(private_ctr_tag, + void *ptr, + ssize_t itemsize, + const std::string &format, + ssize_t ndim, + detail::any_container &&shape_in, + detail::any_container &&strides_in, + bool readonly) + : buffer_info( + ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in), readonly) {} - Py_buffer *view = nullptr; + Py_buffer *m_view = nullptr; bool ownview = false; }; -NAMESPACE_BEGIN(detail) +PYBIND11_NAMESPACE_BEGIN(detail) -template struct compare_buffer_info { - static bool compare(const buffer_info& b) { +template +struct compare_buffer_info { + static bool compare(const buffer_info &b) { return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); } }; -template struct compare_buffer_info::value>> { - static bool compare(const buffer_info& b) { - return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || - ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || - ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info &b) { + return (size_t) b.itemsize == sizeof(T) + && (b.format == format_descriptor::value + || ((sizeof(T) == sizeof(long)) + && b.format == (std::is_unsigned::value ? "L" : "l")) + || ((sizeof(T) == sizeof(size_t)) + && b.format == (std::is_unsigned::value ? 
"N" : "n"))); } }; -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/python/src/pybind11/cast.h b/python/src/pybind11/cast.h index 8d0fd5d90..a0e32281b 100644 --- a/python/src/pybind11/cast.h +++ b/python/src/pybind11/cast.h @@ -10,1006 +10,183 @@ #pragma once -#include "pytypes.h" -#include "detail/typeid.h" +#include "detail/common.h" #include "detail/descr.h" -#include "detail/internals.h" +#include "detail/type_caster_base.h" +#include "detail/typeid.h" +#include "pytypes.h" + #include -#include +#include +#include +#include +#include +#include +#include #include #include +#include +#include -#if defined(PYBIND11_CPP17) -# if defined(__has_include) -# if __has_include() -# define PYBIND11_HAS_STRING_VIEW -# endif -# elif defined(_MSC_VER) -# define PYBIND11_HAS_STRING_VIEW -# endif -#endif -#ifdef PYBIND11_HAS_STRING_VIEW -#include -#endif +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -/// A life support system for temporary objects created by `type_caster::load()`. -/// Adding a patient will keep it alive up until the enclosing function returns. -class loader_life_support { -public: - /// A new patient frame is created when a function is entered - loader_life_support() { - get_internals().loader_patient_stack.push_back(nullptr); - } - - /// ... and destroyed after it returns - ~loader_life_support() { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - pybind11_fail("loader_life_support: internal error"); - - auto ptr = stack.back(); - stack.pop_back(); - Py_CLEAR(ptr); - - // A heuristic to reduce the stack's capacity (e.g. after long recursive calls) - if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2) - stack.shrink_to_fit(); - } - - /// This can only be used inside a pybind11-bound function, either by `argument_loader` - /// at argument preparation time or by `py::cast()` at execution time. - PYBIND11_NOINLINE static void add_patient(handle h) { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - throw cast_error("When called outside a bound function, py::cast() cannot " - "do Python -> C++ conversions which require the creation " - "of temporary values"); - - auto &list_ptr = stack.back(); - if (list_ptr == nullptr) { - list_ptr = PyList_New(1); - if (!list_ptr) - pybind11_fail("loader_life_support: error allocating list"); - PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); - } else { - auto result = PyList_Append(list_ptr, h.ptr()); - if (result == -1) - pybind11_fail("loader_life_support: error adding patient"); - } - } -}; - -// Gets the cache entry for the given type, creating it if necessary. The return value is the pair -// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was -// just created. -inline std::pair all_type_info_get_cache(PyTypeObject *type); - -// Populates a just-created cache entry. 
-PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector &bases) { - std::vector check; - for (handle parent : reinterpret_borrow(t->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); - - auto const &type_dict = get_internals().registered_types_py; - for (size_t i = 0; i < check.size(); i++) { - auto type = check[i]; - // Ignore Python2 old-style class super type: - if (!PyType_Check((PyObject *) type)) continue; - - // Check `type` in the current set of registered python types: - auto it = type_dict.find(type); - if (it != type_dict.end()) { - // We found a cache entry for it, so it's either pybind-registered or has pre-computed - // pybind bases, but we have to make sure we haven't already seen the type(s) before: we - // want to follow Python/virtual C++ rules that there should only be one instance of a - // common base. - for (auto *tinfo : it->second) { - // NB: Could use a second set here, rather than doing a linear search, but since - // having a large number of immediate pybind11-registered types seems fairly - // unlikely, that probably isn't worthwhile. - bool found = false; - for (auto *known : bases) { - if (known == tinfo) { found = true; break; } - } - if (!found) bases.push_back(tinfo); - } - } - else if (type->tp_bases) { - // It's some python type, so keep follow its bases classes to look for one or more - // registered types - if (i + 1 == check.size()) { - // When we're at the end, we can pop off the current element to avoid growing - // `check` when adding just one base (which is typical--i.e. when there is no - // multiple inheritance) - check.pop_back(); - i--; - } - for (handle parent : reinterpret_borrow(type->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); - } - } -} - -/** - * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will - * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side - * derived class that uses single inheritance. Will contain as many types as required for a Python - * class that uses multiple inheritance to inherit (directly or indirectly) from multiple - * pybind-registered classes. Will be empty if neither the type nor any base classes are - * pybind-registered. - * - * The value is cached for the lifetime of the Python type. - */ -inline const std::vector &all_type_info(PyTypeObject *type) { - auto ins = all_type_info_get_cache(type); - if (ins.second) - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); - - return ins.first->second; -} - -/** - * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any - * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use - * `all_type_info` instead if you want to support multiple bases. 
- */ -PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) { - auto &bases = all_type_info(type); - if (bases.size() == 0) - return nullptr; - if (bases.size() > 1) - pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); - return bases.front(); -} - -inline detail::type_info *get_local_type_info(const std::type_index &tp) { - auto &locals = registered_local_types_cpp(); - auto it = locals.find(tp); - if (it != locals.end()) - return it->second; - return nullptr; -} - -inline detail::type_info *get_global_type_info(const std::type_index &tp) { - auto &types = get_internals().registered_types_cpp; - auto it = types.find(tp); - if (it != types.end()) - return it->second; - return nullptr; -} - -/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. -PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp, - bool throw_if_missing = false) { - if (auto ltype = get_local_type_info(tp)) - return ltype; - if (auto gtype = get_global_type_info(tp)) - return gtype; - - if (throw_if_missing) { - std::string tname = tp.name(); - detail::clean_type_id(tname); - pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); - } - return nullptr; -} - -PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { - detail::type_info *type_info = get_type_info(tp, throw_if_missing); - return handle(type_info ? ((PyObject *) type_info->type) : nullptr); -} - -struct value_and_holder { - instance *inst = nullptr; - size_t index = 0u; - const detail::type_info *type = nullptr; - void **vh = nullptr; - - // Main constructor for a found value/holder: - value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : - inst{i}, index{index}, type{type}, - vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} - {} - - // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) - value_and_holder() {} - - // Used for past-the-end iterator - value_and_holder(size_t index) : index{index} {} - - template V *&value_ptr() const { - return reinterpret_cast(vh[0]); - } - // True if this `value_and_holder` has a non-null value pointer - explicit operator bool() const { return value_ptr(); } - - template H &holder() const { - return reinterpret_cast(vh[1]); - } - bool holder_constructed() const { - return inst->simple_layout - ? inst->simple_holder_constructed - : inst->nonsimple.status[index] & instance::status_holder_constructed; - } - void set_holder_constructed(bool v = true) { - if (inst->simple_layout) - inst->simple_holder_constructed = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_holder_constructed; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed; - } - bool instance_registered() const { - return inst->simple_layout - ? 
inst->simple_instance_registered - : inst->nonsimple.status[index] & instance::status_instance_registered; - } - void set_instance_registered(bool v = true) { - if (inst->simple_layout) - inst->simple_instance_registered = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_instance_registered; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered; - } -}; - -// Container for accessing and iterating over an instance's values/holders -struct values_and_holders { -private: - instance *inst; - using type_vec = std::vector; - const type_vec &tinfo; - -public: - values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} - - struct iterator { - private: - instance *inst = nullptr; - const type_vec *types = nullptr; - value_and_holder curr; - friend struct values_and_holders; - iterator(instance *inst, const type_vec *tinfo) - : inst{inst}, types{tinfo}, - curr(inst /* instance */, - types->empty() ? nullptr : (*types)[0] /* type info */, - 0, /* vpos: (non-simple types only): the first vptr comes first */ - 0 /* index */) - {} - // Past-the-end iterator: - iterator(size_t end) : curr(end) {} - public: - bool operator==(const iterator &other) { return curr.index == other.curr.index; } - bool operator!=(const iterator &other) { return curr.index != other.curr.index; } - iterator &operator++() { - if (!inst->simple_layout) - curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; - ++curr.index; - curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; - return *this; - } - value_and_holder &operator*() { return curr; } - value_and_holder *operator->() { return &curr; } - }; - - iterator begin() { return iterator(inst, &tinfo); } - iterator end() { return iterator(tinfo.size()); } - - iterator find(const type_info *find_type) { - auto it = begin(), endit = end(); - while (it != endit && it->type != find_type) ++it; - return it; - } - - size_t size() { return tinfo.size(); } -}; - -/** - * Extracts C++ value and holder pointer references from an instance (which may contain multiple - * values/holders for python-side multiple inheritance) that match the given type. Throws an error - * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If - * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, - * regardless of type (and the resulting .type will be nullptr). - * - * The returned object should be short-lived: in particular, it must not outlive the called-upon - * instance. 
- */ -PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { - // Optimize common case: - if (!find_type || Py_TYPE(this) == find_type->type) - return value_and_holder(this, find_type, 0, 0); - - detail::values_and_holders vhs(this); - auto it = vhs.find(find_type); - if (it != vhs.end()) - return *it; - - if (!throw_if_missing) - return value_and_holder(); - -#if defined(NDEBUG) - pybind11_fail("pybind11::detail::instance::get_value_and_holder: " - "type is not a pybind11 base of the given instance " - "(compile in debug mode for type details)"); -#else - pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + - std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" + - std::string(Py_TYPE(this)->tp_name) + "' instance"); -#endif -} - -PYBIND11_NOINLINE inline void instance::allocate_layout() { - auto &tinfo = all_type_info(Py_TYPE(this)); - - const size_t n_types = tinfo.size(); - - if (n_types == 0) - pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); - - simple_layout = - n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); - - // Simple path: no python-side multiple inheritance, and a small-enough holder - if (simple_layout) { - simple_value_holder[0] = nullptr; - simple_holder_constructed = false; - simple_instance_registered = false; - } - else { // multiple base types or a too-large holder - // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, - // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool - // values that tracks whether each associated holder has been initialized. Each [block] is - // padded, if necessary, to an integer multiple of sizeof(void *). - size_t space = 0; - for (auto t : tinfo) { - space += 1; // value pointer - space += t->holder_size_in_ptrs; // holder instance - } - size_t flags_at = space; - space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) - - // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, - // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 - // they default to using pymalloc, which is designed to be efficient for small allocations - // like the one we're doing here; in earlier versions (and for larger allocations) they are - // just wrappers around malloc. 
-#if PY_VERSION_HEX >= 0x03050000 - nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); -#else - nonsimple.values_and_holders = (void **) PyMem_New(void *, space); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); - std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); -#endif - nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); - } - owned = true; -} - -PYBIND11_NOINLINE inline void instance::deallocate_layout() { - if (!simple_layout) - PyMem_Free(nonsimple.values_and_holders); -} - -PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) { - handle type = detail::get_type_handle(tp, false); - if (!type) - return false; - return isinstance(obj, type); -} - -PYBIND11_NOINLINE inline std::string error_string() { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); - return "Unknown internal error occurred"; - } - - error_scope scope; // Preserve error state - - std::string errorString; - if (scope.type) { - errorString += handle(scope.type).attr("__name__").cast(); - errorString += ": "; - } - if (scope.value) - errorString += (std::string) str(scope.value); - - PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); - -#if PY_MAJOR_VERSION >= 3 - if (scope.trace != nullptr) - PyException_SetTraceback(scope.value, scope.trace); -#endif - -#if !defined(PYPY_VERSION) - if (scope.trace) { - PyTracebackObject *trace = (PyTracebackObject *) scope.trace; - - /* Get the deepest trace possible */ - while (trace->tb_next) - trace = trace->tb_next; - - PyFrameObject *frame = trace->tb_frame; - errorString += "\n\nAt:\n"; - while (frame) { - int lineno = PyFrame_GetLineNumber(frame); - errorString += - " " + handle(frame->f_code->co_filename).cast() + - "(" + std::to_string(lineno) + "): " + - handle(frame->f_code->co_name).cast() + "\n"; - frame = frame->f_back; - } - } -#endif - - return errorString; -} - -PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { - auto &instances = get_internals().registered_instances; - auto range = instances.equal_range(ptr); - for (auto it = range.first; it != range.second; ++it) { - for (auto vh : values_and_holders(it->second)) { - if (vh.type == type) - return handle((PyObject *) it->second); - } - } - return handle(); -} - -inline PyThreadState *get_thread_state_unchecked() { -#if defined(PYPY_VERSION) - return PyThreadState_GET(); -#elif PY_VERSION_HEX < 0x03000000 - return _PyThreadState_Current; -#elif PY_VERSION_HEX < 0x03050000 - return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); -#elif PY_VERSION_HEX < 0x03050200 - return (PyThreadState*) _PyThreadState_Current.value; -#else - return _PyThreadState_UncheckedGet(); -#endif -} - -// Forward declarations -inline void keep_alive_impl(handle nurse, handle patient); -inline PyObject *make_new_instance(PyTypeObject *type); - -class type_caster_generic { -public: - PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) - : typeinfo(get_type_info(type_info)), cpptype(&type_info) { } - - type_caster_generic(const type_info *typeinfo) - : typeinfo(typeinfo), cpptype(typeinfo ? 
typeinfo->cpptype : nullptr) { } - - bool load(handle src, bool convert) { - return load_impl(src, convert); - } - - PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, - const detail::type_info *tinfo, - void *(*copy_constructor)(const void *), - void *(*move_constructor)(const void *), - const void *existing_holder = nullptr) { - if (!tinfo) // no type info: error will be set already - return handle(); - - void *src = const_cast(_src); - if (src == nullptr) - return none().release(); - - auto it_instances = get_internals().registered_instances.equal_range(src); - for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { - for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { - if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) - return handle((PyObject *) it_i->second).inc_ref(); - } - } - - auto inst = reinterpret_steal(make_new_instance(tinfo->type)); - auto wrapper = reinterpret_cast(inst.ptr()); - wrapper->owned = false; - void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); - - switch (policy) { - case return_value_policy::automatic: - case return_value_policy::take_ownership: - valueptr = src; - wrapper->owned = true; - break; - - case return_value_policy::automatic_reference: - case return_value_policy::reference: - valueptr = src; - wrapper->owned = false; - break; - - case return_value_policy::copy: - if (copy_constructor) - valueptr = copy_constructor(src); - else - throw cast_error("return_value_policy = copy, but the " - "object is non-copyable!"); - wrapper->owned = true; - break; - - case return_value_policy::move: - if (move_constructor) - valueptr = move_constructor(src); - else if (copy_constructor) - valueptr = copy_constructor(src); - else - throw cast_error("return_value_policy = move, but the " - "object is neither movable nor copyable!"); - wrapper->owned = true; - break; - - case return_value_policy::reference_internal: - valueptr = src; - wrapper->owned = false; - keep_alive_impl(inst, parent); - break; - - default: - throw cast_error("unhandled return_value_policy: should not happen!"); - } - - tinfo->init_instance(wrapper, existing_holder); - - return inst.release(); - } - - // Base methods for generic caster; there are overridden in copyable_holder_caster - void load_value(value_and_holder &&v_h) { - auto *&vptr = v_h.value_ptr(); - // Lazy allocation for unallocated values: - if (vptr == nullptr) { - auto *type = v_h.type ? 
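A hedged aside, not part of the patch: the switch over return_value_policy above is what the user-visible policies resolve to. A sketch of picking a policy when returning a reference to internal state; Widget and Pen are hypothetical names:

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Pen { int width = 1; };
struct Widget {
    Pen pen;
    Pen &get_pen() { return pen; }
};

PYBIND11_MODULE(example, m) {
    py::class_<Pen>(m, "Pen").def_readwrite("width", &Pen::width);
    py::class_<Widget>(m, "Widget")
        .def(py::init<>())
        // reference_internal: no copy is made (valueptr = src, owned = false in the
        // switch above) and the parent Widget is kept alive by the returned Pen.
        .def("get_pen", &Widget::get_pen, py::return_value_policy::reference_internal);
}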
v_h.type : typeinfo; - if (type->operator_new) { - vptr = type->operator_new(type->type_size); - } else { - #if defined(PYBIND11_CPP17) - if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) - vptr = ::operator new(type->type_size, - (std::align_val_t) type->type_align); - else - #endif - vptr = ::operator new(type->type_size); - } - } - value = vptr; - } - bool try_implicit_casts(handle src, bool convert) { - for (auto &cast : typeinfo->implicit_casts) { - type_caster_generic sub_caster(*cast.first); - if (sub_caster.load(src, convert)) { - value = cast.second(sub_caster.value); - return true; - } - } - return false; - } - bool try_direct_conversions(handle src) { - for (auto &converter : *typeinfo->direct_conversions) { - if (converter(src.ptr(), value)) - return true; - } - return false; - } - void check_holder_compat() {} - - PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { - auto caster = type_caster_generic(ti); - if (caster.load(src, false)) - return caster.value; - return nullptr; - } - - /// Try to load with foreign typeinfo, if available. Used when there is no - /// native typeinfo, or when the native one wasn't able to produce a value. - PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { - constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; - const auto pytype = src.get_type(); - if (!hasattr(pytype, local_key)) - return false; - - type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); - // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type - if (foreign_typeinfo->module_local_load == &local_load - || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) - return false; - - if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { - value = result; - return true; - } - return false; - } - - // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant - // bits of code between here and copyable_holder_caster where the two classes need different - // logic (without having to resort to virtual inheritance). - template - PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { - if (!src) return false; - if (!typeinfo) return try_load_foreign_module_local(src); - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - value = nullptr; - return true; - } - - auto &this_ = static_cast(*this); - this_.check_holder_compat(); - - PyTypeObject *srctype = Py_TYPE(src.ptr()); - - // Case 1: If src is an exact type match for the target type then we can reinterpret_cast - // the instance's value pointer to the target type: - if (srctype == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2: We have a derived class - else if (PyType_IsSubtype(srctype, typeinfo->type)) { - auto &bases = all_type_info(srctype); - bool no_cpp_mi = typeinfo->simple_type; - - // Case 2a: the python type is a Python-inherited derived class that inherits from just - // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of - // the right type and we can use reinterpret_cast. 
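Illustrative aside, not part of the upstream patch: the case analysis in load_impl() above (exact type match, Python-derived classes, C++ multiple inheritance) is what lets one bound function accept a whole registered hierarchy. A minimal sketch with hypothetical Animal/Dog types:

#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

struct Animal {
    virtual ~Animal() = default;
    virtual std::string noise() const { return "..."; }
};
struct Dog : Animal {
    std::string noise() const override { return "woof"; }
};

PYBIND11_MODULE(example, m) {
    py::class_<Animal>(m, "Animal").def(py::init<>());
    py::class_<Dog, Animal>(m, "Dog").def(py::init<>());
    // Accepts Animal, Dog, and Python classes derived from either: load_impl()
    // locates the registered base and hands back the underlying C++ pointer.
    m.def("noise", [](const Animal &a) { return a.noise(); });
}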
- // (This is essentially the same as case 2b, but because not using multiple inheritance - // is extremely common, we handle it specially to avoid the loop iterator and type - // pointer lookup overhead) - if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if - // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we - // can safely reinterpret_cast to the relevant pointer. - else if (bases.size() > 1) { - for (auto base : bases) { - if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); - return true; - } - } - } - - // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match - // in the registered bases, above, so try implicit casting (needed for proper C++ casting - // when MI is involved). - if (this_.try_implicit_casts(src, convert)) - return true; - } - - // Perform an implicit conversion - if (convert) { - for (auto &converter : typeinfo->implicit_conversions) { - auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); - if (load_impl(temp, false)) { - loader_life_support::add_patient(temp); - return true; - } - } - if (this_.try_direct_conversions(src)) - return true; - } - - // Failed to match local typeinfo. Try again with global. - if (typeinfo->module_local) { - if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { - typeinfo = gtype; - return load(src, false); - } - } - - // Global typeinfo has precedence over foreign module_local - return try_load_foreign_module_local(src); - } - - - // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast - // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair - // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). - PYBIND11_NOINLINE static std::pair src_and_type( - const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { - if (auto *tpi = get_type_info(cast_type)) - return {src, const_cast(tpi)}; - - // Not found, set error: - std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); - detail::clean_type_id(tname); - std::string msg = "Unregistered type : " + tname; - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return {nullptr, nullptr}; - } - - const type_info *typeinfo = nullptr; - const std::type_info *cpptype = nullptr; - void *value = nullptr; -}; - -/** - * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster - * needs to provide `operator T*()` and `operator T&()` operators. - * - * If the type supports moving the value away via an `operator T&&() &&` method, it should use - * `movable_cast_op_type` instead. - */ -template -using cast_op_type = - conditional_t>::value, - typename std::add_pointer>::type, - typename std::add_lvalue_reference>::type>; - -/** - * Determine suitable casting operator for a type caster with a movable value. Such a type caster - * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be - * called in appropriate contexts where the value can be moved rather than copied. - * - * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. 
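Aside, not from the patch: the implicit-conversion pass in load_impl() above is what py::implicitly_convertible() feeds. A minimal sketch with hypothetical Celsius/Kelvin types:

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Celsius {
    explicit Celsius(double d) : deg(d) {}
    double deg;
};
struct Kelvin {
    explicit Kelvin(double d) : deg(d) {}
    explicit Kelvin(const Celsius &c) : deg(c.deg + 273.15) {}
    double deg;
};

PYBIND11_MODULE(example, m) {
    py::class_<Celsius>(m, "Celsius").def(py::init<double>());
    py::class_<Kelvin>(m, "Kelvin")
        .def(py::init<double>())
        .def(py::init<const Celsius &>());
    // Registers a converter tried by the loop above, so a Celsius argument is
    // accepted wherever a Kelvin parameter is expected.
    py::implicitly_convertible<Celsius, Kelvin>();
    m.def("boiling", [](const Kelvin &k) { return k.deg >= 373.15; });
}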
- */ -template -using movable_cast_op_type = - conditional_t::type>::value, - typename std::add_pointer>::type, - conditional_t::value, - typename std::add_rvalue_reference>::type, - typename std::add_lvalue_reference>::type>>; - -// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when -// T is non-copyable, but code containing such a copy constructor fails to actually compile. -template struct is_copy_constructible : std::is_copy_constructible {}; - -// Specialization for types that appear to be copy constructible but also look like stl containers -// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if -// so, copy constructability depends on whether the value_type is copy constructible. -template struct is_copy_constructible, - std::is_same - >::value>> : is_copy_constructible {}; - -#if !defined(PYBIND11_CPP17) -// Likewise for std::pair before C++17 (which mandates that the copy constructor not exist when the -// two types aren't themselves copy constructible). -template struct is_copy_constructible> - : all_of, is_copy_constructible> {}; -#endif - -NAMESPACE_END(detail) - -// polymorphic_type_hook::get(src, tinfo) determines whether the object pointed -// to by `src` actually is an instance of some class derived from `itype`. -// If so, it sets `tinfo` to point to the std::type_info representing that derived -// type, and returns a pointer to the start of the most-derived object of that type -// (in which `src` is a subobject; this will be the same address as `src` in most -// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src` -// and leaves `tinfo` at its default value of nullptr. -// -// The default polymorphic_type_hook just returns src. A specialization for polymorphic -// types determines the runtime type of the passed object and adjusts the this-pointer -// appropriately via dynamic_cast. This is what enables a C++ Animal* to appear -// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is -// registered with pybind11, and this Animal is in fact a Dog). -// -// You may specialize polymorphic_type_hook yourself for types that want to appear -// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern -// in performance-sensitive applications, used most notably in LLVM.) -template -struct polymorphic_type_hook -{ - static const void *get(const itype *src, const std::type_info*&) { return src; } -}; -template -struct polymorphic_type_hook::value>> -{ - static const void *get(const itype *src, const std::type_info*& type) { - type = src ? 
&typeid(*src) : nullptr; - return dynamic_cast(src); - } -}; - -NAMESPACE_BEGIN(detail) - -/// Generic type caster for objects stored on the heap -template class type_caster_base : public type_caster_generic { - using itype = intrinsic_t; - -public: - static constexpr auto name = _(); - - type_caster_base() : type_caster_base(typeid(type)) { } - explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } - - static handle cast(const itype &src, return_value_policy policy, handle parent) { - if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) - policy = return_value_policy::copy; - return cast(&src, policy, parent); - } - - static handle cast(itype &&src, return_value_policy, handle parent) { - return cast(&src, return_value_policy::move, parent); - } - - // Returns a (pointer, type_info) pair taking care of necessary type lookup for a - // polymorphic type (using RTTI by default, but can be overridden by specializing - // polymorphic_type_hook). If the instance isn't derived, returns the base version. - static std::pair src_and_type(const itype *src) { - auto &cast_type = typeid(itype); - const std::type_info *instance_type = nullptr; - const void *vsrc = polymorphic_type_hook::get(src, instance_type); - if (instance_type && !same_type(cast_type, *instance_type)) { - // This is a base pointer to a derived type. If the derived type is registered - // with pybind11, we want to make the full derived object available. - // In the typical case where itype is polymorphic, we get the correct - // derived pointer (which may be != base pointer) by a dynamic_cast to - // most derived type. If itype is not polymorphic, we won't get here - // except via a user-provided specialization of polymorphic_type_hook, - // and the user has promised that no this-pointer adjustment is - // required in that case, so it's OK to use static_cast. - if (const auto *tpi = get_type_info(*instance_type)) - return {vsrc, tpi}; - } - // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so - // don't do a cast - return type_caster_generic::src_and_type(src, cast_type, instance_type); - } - - static handle cast(const itype *src, return_value_policy policy, handle parent) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, policy, parent, st.second, - make_copy_constructor(src), make_move_constructor(src)); - } - - static handle cast_holder(const itype *src, const void *holder) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, return_value_policy::take_ownership, {}, st.second, - nullptr, nullptr, holder); - } - - template using cast_op_type = detail::cast_op_type; - - operator itype*() { return (type *) value; } - operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } - -protected: - using Constructor = void *(*)(const void *); - - /* Only enabled when the types are {copy,move}-constructible *and* when the type - does not have a private operator new implementation. 
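Illustrative aside, not part of the upstream patch: polymorphic_type_hook, documented a little above, lets a hierarchy without C++ RTTI still surface the most-derived registered type. A sketch with hypothetical Shape/Circle types carrying their own kind tag:

#include <pybind11/pybind11.h>
#include <typeinfo>
namespace py = pybind11;

struct Shape {
    enum Kind { kCircle, kSquare } kind;
};
struct Circle : Shape {
    double radius = 0;
};

namespace pybind11 {
template <>
struct polymorphic_type_hook<Shape> {
    static const void *get(const Shape *src, const std::type_info *&type) {
        if (src && src->kind == Shape::kCircle) {
            type = &typeid(Circle);  // report the derived type to pybind11
            return static_cast<const Circle *>(src);
        }
        return src;  // leave `type` at nullptr: treat as a plain Shape
    }
};
} // namespace pybind11

// Both classes still need to be registered, e.g.:
//     py::class_<Shape>(m, "Shape");
//     py::class_<Circle, Shape>(m, "Circle");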
*/ - template ::value>> - static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { - return [](const void *arg) -> void * { - return new T(*reinterpret_cast(arg)); - }; - } - - template ::value>> - static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { - return [](const void *arg) -> void * { - return new T(std::move(*const_cast(reinterpret_cast(arg)))); - }; - } - - static Constructor make_copy_constructor(...) { return nullptr; } - static Constructor make_move_constructor(...) { return nullptr; } -}; - -template class type_caster : public type_caster_base { }; -template using make_caster = type_caster>; +template +class type_caster : public type_caster_base {}; +template +using make_caster = type_caster>; // Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T -template typename make_caster::template cast_op_type cast_op(make_caster &caster) { +template +typename make_caster::template cast_op_type cast_op(make_caster &caster) { return caster.operator typename make_caster::template cast_op_type(); } -template typename make_caster::template cast_op_type::type> +template +typename make_caster::template cast_op_type::type> cast_op(make_caster &&caster) { - return std::move(caster).operator - typename make_caster::template cast_op_type::type>(); + return std::move(caster).operator typename make_caster:: + template cast_op_type::type>(); } -template class type_caster> { +template +class type_caster> { private: using caster_t = make_caster; caster_t subcaster; - using subcaster_cast_op_type = typename caster_t::template cast_op_type; - static_assert(std::is_same::type &, subcaster_cast_op_type>::value, - "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); + using reference_t = type &; + using subcaster_cast_op_type = typename caster_t::template cast_op_type; + + static_assert( + std::is_same::type &, subcaster_cast_op_type>::value + || std::is_same::value, + "std::reference_wrapper caster requires T to have a caster with an " + "`operator T &()` or `operator const T &()`"); + public: bool load(handle src, bool convert) { return subcaster.load(src, convert); } static constexpr auto name = caster_t::name; - static handle cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { + static handle + cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { // It is definitely wrong to take ownership of this pointer, so mask that rvp - if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic) + if (policy == return_value_policy::take_ownership + || policy == return_value_policy::automatic) { policy = return_value_policy::automatic_reference; + } return caster_t::cast(&src.get(), policy, parent); } - template using cast_op_type = std::reference_wrapper; - operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } + template + using cast_op_type = std::reference_wrapper; + explicit operator std::reference_wrapper() { return cast_op(subcaster); } }; -#define PYBIND11_TYPE_CASTER(type, py_name) \ - protected: \ - type value; \ - public: \ - static constexpr auto name = py_name; \ - template >::value, int> = 0> \ - static handle cast(T_ *src, return_value_policy policy, handle parent) { \ - if (!src) return none().release(); \ - if (policy == return_value_policy::take_ownership) { \ - auto h = cast(std::move(*src), policy, parent); 
delete src; return h; \ - } else { \ - return cast(*src, policy, parent); \ - } \ - } \ - operator type*() { return &value; } \ - operator type&() { return value; } \ - operator type&&() && { return std::move(value); } \ - template using cast_op_type = pybind11::detail::movable_cast_op_type +#define PYBIND11_TYPE_CASTER(type, py_name) \ +protected: \ + type value; \ + \ +public: \ + static constexpr auto name = py_name; \ + template >::value, \ + int> = 0> \ + static ::pybind11::handle cast( \ + T_ *src, ::pybind11::return_value_policy policy, ::pybind11::handle parent) { \ + if (!src) \ + return ::pybind11::none().release(); \ + if (policy == ::pybind11::return_value_policy::take_ownership) { \ + auto h = cast(std::move(*src), policy, parent); \ + delete src; \ + return h; \ + } \ + return cast(*src, policy, parent); \ + } \ + operator type *() { return &value; } /* NOLINT(bugprone-macro-parentheses) */ \ + operator type &() { return value; } /* NOLINT(bugprone-macro-parentheses) */ \ + operator type &&() && { return std::move(value); } /* NOLINT(bugprone-macro-parentheses) */ \ + template \ + using cast_op_type = ::pybind11::detail::movable_cast_op_type - -template using is_std_char_type = any_of< - std::is_same, /* std::string */ - std::is_same, /* std::u16string */ - std::is_same, /* std::u32string */ - std::is_same /* std::wstring */ ->; +template +using is_std_char_type = any_of, /* std::string */ +#if defined(PYBIND11_HAS_U8STRING) + std::is_same, /* std::u8string */ +#endif + std::is_same, /* std::u16string */ + std::is_same, /* std::u32string */ + std::is_same /* std::wstring */ + >; template struct type_caster::value && !is_std_char_type::value>> { using _py_type_0 = conditional_t; - using _py_type_1 = conditional_t::value, _py_type_0, typename std::make_unsigned<_py_type_0>::type>; + using _py_type_1 = conditional_t::value, + _py_type_0, + typename std::make_unsigned<_py_type_0>::type>; using py_type = conditional_t::value, double, _py_type_1>; -public: +public: bool load(handle src, bool convert) { py_type py_value; - if (!src) + if (!src) { return false; - - if (std::is_floating_point::value) { - if (convert || PyFloat_Check(src.ptr())) - py_value = (py_type) PyFloat_AsDouble(src.ptr()); - else - return false; - } else if (PyFloat_Check(src.ptr())) { - return false; - } else if (std::is_unsigned::value) { - py_value = as_unsigned(src.ptr()); - } else { // signed integer: - py_value = sizeof(T) <= sizeof(long) - ? (py_type) PyLong_AsLong(src.ptr()) - : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr()); } - bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); - if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && - (py_value < (py_type) std::numeric_limits::min() || - py_value > (py_type) std::numeric_limits::max()))) { - bool type_error = py_err && PyErr_ExceptionMatches( -#if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION) - PyExc_SystemError +#if !defined(PYPY_VERSION) + auto index_check = [](PyObject *o) { return PyIndex_Check(o); }; #else - PyExc_TypeError + // In PyPy 7.3.3, `PyIndex_Check` is implemented by calling `__index__`, + // while CPython only considers the existence of `nb_index`/`__index__`. 
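Aside, not part of the patch: the reworked PYBIND11_TYPE_CASTER macro above (now spelled with const_name) is the usual entry point for hand-written casters. A minimal sketch for a hypothetical Point2D that round-trips as a Python (x, y) tuple:

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Point2D { double x = 0, y = 0; };

namespace pybind11 { namespace detail {
template <>
struct type_caster<Point2D> {
public:
    PYBIND11_TYPE_CASTER(Point2D, const_name("Point2D"));

    // Python -> C++ (a real caster would also validate the element types)
    bool load(handle src, bool /*convert*/) {
        if (!py::isinstance<py::tuple>(src)) return false;
        auto t = py::reinterpret_borrow<py::tuple>(src);
        if (t.size() != 2) return false;
        value.x = t[0].cast<double>();
        value.y = t[1].cast<double>();
        return true;
    }

    // C++ -> Python
    static handle cast(const Point2D &p, return_value_policy, handle) {
        return py::make_tuple(p.x, p.y).release();
    }
};
}} // namespace pybind11::detail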
+ auto index_check = [](PyObject *o) { return hasattr(o, "__index__"); }; #endif - ); + + if (std::is_floating_point::value) { + if (convert || PyFloat_Check(src.ptr())) { + py_value = (py_type) PyFloat_AsDouble(src.ptr()); + } else { + return false; + } + } else if (PyFloat_Check(src.ptr()) + || (!convert && !PYBIND11_LONG_CHECK(src.ptr()) && !index_check(src.ptr()))) { + return false; + } else { + handle src_or_index = src; + // PyPy: 7.3.7's 3.8 does not implement PyLong_*'s __index__ calls. +#if PY_VERSION_HEX < 0x03080000 || defined(PYPY_VERSION) + object index; + if (!PYBIND11_LONG_CHECK(src.ptr())) { // So: index_check(src.ptr()) + index = reinterpret_steal(PyNumber_Index(src.ptr())); + if (!index) { + PyErr_Clear(); + if (!convert) + return false; + } else { + src_or_index = index; + } + } +#endif + if (std::is_unsigned::value) { + py_value = as_unsigned(src_or_index.ptr()); + } else { // signed integer: + py_value = sizeof(T) <= sizeof(long) + ? (py_type) PyLong_AsLong(src_or_index.ptr()) + : (py_type) PYBIND11_LONG_AS_LONGLONG(src_or_index.ptr()); + } + } + + // Python API reported an error + bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); + + // Check to see if the conversion is valid (integers should match exactly) + // Signed/unsigned checks happen elsewhere + if (py_err + || (std::is_integral::value && sizeof(py_type) != sizeof(T) + && py_value != (py_type) (T) py_value)) { PyErr_Clear(); - if (type_error && convert && PyNumber_Check(src.ptr())) { + if (py_err && convert && (PyNumber_Check(src.ptr()) != 0)) { auto tmp = reinterpret_steal(std::is_floating_point::value - ? PyNumber_Float(src.ptr()) - : PyNumber_Long(src.ptr())); + ? PyNumber_Float(src.ptr()) + : PyNumber_Long(src.ptr())); PyErr_Clear(); return load(tmp, false); } @@ -1020,62 +197,75 @@ public: return true; } - template + template static typename std::enable_if::value, handle>::type cast(U src, return_value_policy /* policy */, handle /* parent */) { return PyFloat_FromDouble((double) src); } - template - static typename std::enable_if::value && std::is_signed::value && (sizeof(U) <= sizeof(long)), handle>::type + template + static typename std::enable_if::value && std::is_signed::value + && (sizeof(U) <= sizeof(long)), + handle>::type cast(U src, return_value_policy /* policy */, handle /* parent */) { return PYBIND11_LONG_FROM_SIGNED((long) src); } - template - static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) <= sizeof(unsigned long)), handle>::type + template + static typename std::enable_if::value && std::is_unsigned::value + && (sizeof(U) <= sizeof(unsigned long)), + handle>::type cast(U src, return_value_policy /* policy */, handle /* parent */) { return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src); } - template - static typename std::enable_if::value && std::is_signed::value && (sizeof(U) > sizeof(long)), handle>::type + template + static typename std::enable_if::value && std::is_signed::value + && (sizeof(U) > sizeof(long)), + handle>::type cast(U src, return_value_policy /* policy */, handle /* parent */) { return PyLong_FromLongLong((long long) src); } - template - static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) > sizeof(unsigned long)), handle>::type + template + static typename std::enable_if::value && std::is_unsigned::value + && (sizeof(U) > sizeof(unsigned long)), + handle>::type cast(U src, return_value_policy /* policy */, handle /* parent */) { return PyLong_FromUnsignedLongLong((unsigned long long) src); } - 
PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); + PYBIND11_TYPE_CASTER(T, const_name::value>("int", "float")); }; -template struct void_caster { +template +struct void_caster { public: bool load(handle src, bool) { - if (src && src.is_none()) + if (src && src.is_none()) { return true; + } return false; } static handle cast(T, return_value_policy /* policy */, handle /* parent */) { return none().inc_ref(); } - PYBIND11_TYPE_CASTER(T, _("None")); + PYBIND11_TYPE_CASTER(T, const_name("None")); }; -template <> class type_caster : public void_caster {}; +template <> +class type_caster : public void_caster {}; -template <> class type_caster : public type_caster { +template <> +class type_caster : public type_caster { public: using type_caster::cast; bool load(handle h, bool) { if (!h) { return false; - } else if (h.is_none()) { + } + if (h.is_none()) { value = nullptr; return true; } @@ -1087,7 +277,7 @@ public: } /* Check if this is a C++ type */ - auto &bases = all_type_info((PyTypeObject *) h.get_type().ptr()); + const auto &bases = all_type_info((PyTypeObject *) type::handle_of(h).ptr()); if (bases.size() == 1) { // Only allowing loading from a single-value type value = values_and_holders(reinterpret_cast(h.ptr())).begin()->value_ptr(); return true; @@ -1098,188 +288,247 @@ public: } static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { - if (ptr) + if (ptr) { return capsule(ptr).release(); - else - return none().inc_ref(); + } + return none().inc_ref(); } - template using cast_op_type = void*&; - operator void *&() { return value; } - static constexpr auto name = _("capsule"); + template + using cast_op_type = void *&; + explicit operator void *&() { return value; } + static constexpr auto name = const_name("capsule"); + private: void *value = nullptr; }; -template <> class type_caster : public void_caster { }; +template <> +class type_caster : public void_caster {}; -template <> class type_caster { +template <> +class type_caster { public: bool load(handle src, bool convert) { - if (!src) return false; - else if (src.ptr() == Py_True) { value = true; return true; } - else if (src.ptr() == Py_False) { value = false; return true; } - else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { + if (!src) { + return false; + } + if (src.ptr() == Py_True) { + value = true; + return true; + } + if (src.ptr() == Py_False) { + value = false; + return true; + } + if (convert || (std::strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name) == 0)) { // (allow non-implicit conversion for numpy booleans) Py_ssize_t res = -1; if (src.is_none()) { - res = 0; // None is implicitly converted to False + res = 0; // None is implicitly converted to False } - #if defined(PYPY_VERSION) - // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists +#if defined(PYPY_VERSION) + // On PyPy, check that "__bool__" attr exists else if (hasattr(src, PYBIND11_BOOL_ATTR)) { res = PyObject_IsTrue(src.ptr()); } - #else +#else // Alternate approach for CPython: this does the same as the above, but optimized // using the CPython API so as to avoid an unneeded attribute lookup. 
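Illustrative aside, not part of the upstream patch: the arithmetic caster above now accepts any object implementing __index__ for integer parameters while continuing to reject floats. A sketch of the resulting Python-facing behaviour (numpy usage assumed):

#include <pybind11/pybind11.h>
namespace py = pybind11;

PYBIND11_MODULE(example, m) {
    m.def("square", [](int n) { return n * n; });
    // In Python (assumed behaviour of the caster above):
    //   example.square(7)            -> 49
    //   example.square(np.int32(7))  -> 49        (accepted via __index__)
    //   example.square(7.5)          -> TypeError (floats are never truncated)
}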
-            else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
+            else if (auto *tp_as_number = src.ptr()->ob_type->tp_as_number) {
                 if (PYBIND11_NB_BOOL(tp_as_number)) {
                     res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
                 }
             }
-            #endif
+#endif
             if (res == 0 || res == 1) {
-                value = (bool) res;
+                value = (res != 0);
                 return true;
             }
+            PyErr_Clear();
         }
         return false;
     }
     static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
         return handle(src ? Py_True : Py_False).inc_ref();
     }
-    PYBIND11_TYPE_CASTER(bool, _("bool"));
+    PYBIND11_TYPE_CASTER(bool, const_name("bool"));
 };

 // Helper class for UTF-{8,16,32} C++ stl strings:
-template <typename StringType, bool IsView = false> struct string_caster {
+template <typename StringType, bool IsView = false>
+struct string_caster {
     using CharT = typename StringType::value_type;

     // Simplify life by being able to assume standard char sizes (the standard only guarantees
     // minimums, but Python requires exact sizes)
-    static_assert(!std::is_same<CharT, char>::value || sizeof(CharT) == 1, "Unsupported char size != 1");
-    static_assert(!std::is_same<CharT, char16_t>::value || sizeof(CharT) == 2, "Unsupported char16_t size != 2");
-    static_assert(!std::is_same<CharT, char32_t>::value || sizeof(CharT) == 4, "Unsupported char32_t size != 4");
+    static_assert(!std::is_same<CharT, char>::value || sizeof(CharT) == 1,
+                  "Unsupported char size != 1");
+#if defined(PYBIND11_HAS_U8STRING)
+    static_assert(!std::is_same<CharT, char8_t>::value || sizeof(CharT) == 1,
+                  "Unsupported char8_t size != 1");
+#endif
+    static_assert(!std::is_same<CharT, char16_t>::value || sizeof(CharT) == 2,
+                  "Unsupported char16_t size != 2");
+    static_assert(!std::is_same<CharT, char32_t>::value || sizeof(CharT) == 4,
+                  "Unsupported char32_t size != 4");
     // wchar_t can be either 16 bits (Windows) or 32 (everywhere else)
     static_assert(!std::is_same<CharT, wchar_t>::value || sizeof(CharT) == 2 || sizeof(CharT) == 4,
-            "Unsupported wchar_t size != 2/4");
+                  "Unsupported wchar_t size != 2/4");
     static constexpr size_t UTF_N = 8 * sizeof(CharT);

     bool load(handle src, bool) {
-#if PY_MAJOR_VERSION < 3
-        object temp;
-#endif
         handle load_src = src;
         if (!src) {
             return false;
-        } else if (!PyUnicode_Check(load_src.ptr())) {
-#if PY_MAJOR_VERSION >= 3
-            return load_bytes(load_src);
-#else
-            if (sizeof(CharT) == 1) {
-                return load_bytes(load_src);
-            }
-
-            // The below is a guaranteed failure in Python 3 when PyUnicode_Check returns false
-            if (!PYBIND11_BYTES_CHECK(load_src.ptr()))
-                return false;
-
-            temp = reinterpret_steal<object>(PyUnicode_FromObject(load_src.ptr()));
-            if (!temp) { PyErr_Clear(); return false; }
-            load_src = temp;
-#endif
+        }
+        if (!PyUnicode_Check(load_src.ptr())) {
+            return load_raw(load_src);
         }

-        object utfNbytes = reinterpret_steal<object>(PyUnicode_AsEncodedString(
-            load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr));
-        if (!utfNbytes) { PyErr_Clear(); return false; }
+        // For UTF-8 we avoid the need for a temporary `bytes` object by using
+        // `PyUnicode_AsUTF8AndSize`.
+        if (PYBIND11_SILENCE_MSVC_C4127(UTF_N == 8)) {
+            Py_ssize_t size = -1;
+            const auto *buffer
+                = reinterpret_cast<const char *>(PyUnicode_AsUTF8AndSize(load_src.ptr(), &size));
+            if (!buffer) {
+                PyErr_Clear();
+                return false;
+            }
+            value = StringType(buffer, static_cast<size_t>(size));
+            return true;
+        }

-        const CharT *buffer = reinterpret_cast<const CharT *>(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr()));
+        auto utfNbytes
+            = reinterpret_steal<object>(PyUnicode_AsEncodedString(load_src.ptr(),
+                                                                  UTF_N == 8    ? "utf-8"
+                                                                  : UTF_N == 16 ? "utf-16"
+                                                                                : "utf-32",
+                                                                  nullptr));
+        if (!utfNbytes) {
+            PyErr_Clear();
+            return false;
+        }
+
+        const auto *buffer
+            = reinterpret_cast<const CharT *>(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr()));
         size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT);
-        if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32
+        // Skip BOM for UTF-16/32
+        if (PYBIND11_SILENCE_MSVC_C4127(UTF_N > 8)) {
+            buffer++;
+            length--;
+        }
         value = StringType(buffer, length);

         // If we're loading a string_view we need to keep the encoded Python object alive:
-        if (IsView)
+        if (IsView) {
             loader_life_support::add_patient(utfNbytes);
+        }

         return true;
     }

-    static handle cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) {
+    static handle
+    cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) {
         const char *buffer = reinterpret_cast<const char *>(src.data());
-        ssize_t nbytes = ssize_t(src.size() * sizeof(CharT));
+        auto nbytes = ssize_t(src.size() * sizeof(CharT));
         handle s = decode_utfN(buffer, nbytes);
-        if (!s) throw error_already_set();
+        if (!s) {
+            throw error_already_set();
+        }
         return s;
     }

-    PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME));
+    PYBIND11_TYPE_CASTER(StringType, const_name(PYBIND11_STRING_NAME));

 private:
     static handle decode_utfN(const char *buffer, ssize_t nbytes) {
 #if !defined(PYPY_VERSION)
-        return
-            UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) :
-            UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) :
-                          PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr);
+        return UTF_N == 8    ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr)
+               : UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr)
+                             : PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr);
 #else
-        // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version
-        // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a
-        // non-const char * arguments, which is also a nuisance, so bypass the whole thing by just
-        // passing the encoding as a string value, which works properly:
-        return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr);
+        // PyPy segfaults when on PyUnicode_DecodeUTF16 (and possibly on PyUnicode_DecodeUTF32 as
+        // well), so bypass the whole thing by just passing the encoding as a string value, which
+        // works properly:
+        return PyUnicode_Decode(buffer,
+                                nbytes,
+                                UTF_N == 8    ? "utf-8"
+                                : UTF_N == 16 ? "utf-16"
+                                              : "utf-32",
+                                nullptr);
 #endif
     }

-    // When loading into a std::string or char*, accept a bytes object as-is (i.e.
+    // When loading into a std::string or char*, accept a bytes/bytearray object as-is (i.e.
     // without any encoding/decoding attempt). For other C++ char sizes this is a no-op.
     // which supports loading a unicode from a str, doesn't take this path.
     template <typename C = CharT>
-    bool load_bytes(enable_if_t<std::is_same<C, char>::value, handle> src) {
+    bool load_raw(enable_if_t<std::is_same<C, char>::value, handle> src) {
         if (PYBIND11_BYTES_CHECK(src.ptr())) {
-            // We were passed a Python 3 raw bytes; accept it into a std::string or char*
+            // We were passed raw bytes; accept it into a std::string or char*
             // without any encoding attempt.
             const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr());
-            if (bytes) {
-                value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr()));
-                return true;
+            if (!bytes) {
+                pybind11_fail("Unexpected PYBIND11_BYTES_AS_STRING() failure.");
             }
+            value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr()));
+            return true;
+        }
+        if (PyByteArray_Check(src.ptr())) {
+            // We were passed a bytearray; accept it into a std::string or char*
+            // without any encoding attempt.
+            const char *bytearray = PyByteArray_AsString(src.ptr());
+            if (!bytearray) {
+                pybind11_fail("Unexpected PyByteArray_AsString() failure.");
+            }
+            value = StringType(bytearray, (size_t) PyByteArray_Size(src.ptr()));
+            return true;
         }
         return false;
     }

     template <typename C = CharT>
-    bool load_bytes(enable_if_t<!std::is_same<C, char>::value, handle>) { return false; }
+    bool load_raw(enable_if_t<!std::is_same<C, char>::value, handle>) {
+        return false;
+    }
 };

 template <typename CharT, class Traits, class Allocator>
-struct type_caster<std::basic_string<CharT, Traits, Allocator>, enable_if_t<is_std_char_type<CharT>::value>>
+struct type_caster<std::basic_string<CharT, Traits, Allocator>,
+                   enable_if_t<is_std_char_type<CharT>::value>>
     : string_caster<std::basic_string<CharT, Traits, Allocator>> {};

 #ifdef PYBIND11_HAS_STRING_VIEW
 template <typename CharT, class Traits>
-struct type_caster<std::basic_string_view<CharT, Traits>, enable_if_t<is_std_char_type<CharT>::value>>
+struct type_caster<std::basic_string_view<CharT, Traits>,
+                   enable_if_t<is_std_char_type<CharT>::value>>
     : string_caster<std::basic_string_view<CharT, Traits>, true> {};
 #endif

 // Type caster for C-style strings. We basically use a std::string type caster, but also add the
 // ability to use None as a nullptr char* (which the string caster doesn't allow).
-template <typename CharT> struct type_caster<CharT, enable_if_t<is_std_char_type<CharT>::value>> {
+template <typename CharT>
+struct type_caster<CharT, enable_if_t<is_std_char_type<CharT>::value>> {
     using StringType = std::basic_string<CharT>;
-    using StringCaster = type_caster<StringType>;
+    using StringCaster = make_caster<StringType>;
     StringCaster str_caster;
     bool none = false;
     CharT one_char = 0;
+
 public:
     bool load(handle src, bool convert) {
-        if (!src) return false;
+        if (!src) {
+            return false;
+        }
         if (src.is_none()) {
             // Defer accepting None to other overloads (if we aren't in convert mode):
-            if (!convert) return false;
+            if (!convert) {
+                return false;
+            }
             none = true;
             return true;
         }
@@ -1287,45 +536,58 @@ public:
     }

     static handle cast(const CharT *src, return_value_policy policy, handle parent) {
-        if (src == nullptr) return pybind11::none().inc_ref();
+        if (src == nullptr) {
+            return pybind11::none().inc_ref();
+        }
         return StringCaster::cast(StringType(src), policy, parent);
     }

     static handle cast(CharT src, return_value_policy policy, handle parent) {
         if (std::is_same<CharT, char>::value) {
             handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr);
-            if (!s) throw error_already_set();
+            if (!s) {
+                throw error_already_set();
+            }
             return s;
         }
         return StringCaster::cast(StringType(1, src), policy, parent);
     }

-    operator CharT*() { return none ? nullptr : const_cast<CharT *>(static_cast<StringType &>(str_caster).c_str()); }
-    operator CharT&() {
-        if (none)
+    explicit operator CharT *() {
+        return none ? nullptr : const_cast<CharT *>(static_cast<StringType &>(str_caster).c_str());
+    }
+    explicit operator CharT &() {
+        if (none) {
             throw value_error("Cannot convert None to a character");
+        }

         auto &value = static_cast<StringType &>(str_caster);
         size_t str_len = value.size();
-        if (str_len == 0)
+        if (str_len == 0) {
             throw value_error("Cannot convert empty string to a character");
+        }

         // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that
-        // is too high, and one for multiple unicode characters (caught later), so we need to figure
-        // out how long the first encoded character is in bytes to distinguish between these two
-        // errors. We also allow want to allow unicode characters U+0080 through U+00FF, as those
-        // can fit into a single char value.
-        if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) {
-            unsigned char v0 = static_cast<unsigned char>(value[0]);
-            size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127
-                (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence
-                (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence
-                4; // 0b11110xxx - start of 4-byte sequence
+        // is too high, and one for multiple unicode characters (caught later), so we need to
+        // figure out how long the first encoded character is in bytes to distinguish between these
+        // two errors. We also allow want to allow unicode characters U+0080 through U+00FF, as
+        // those can fit into a single char value.
+        if (PYBIND11_SILENCE_MSVC_C4127(StringCaster::UTF_N == 8) && str_len > 1 && str_len <= 4) {
+            auto v0 = static_cast<unsigned char>(value[0]);
+            // low bits only: 0-127
+            // 0b110xxxxx - start of 2-byte sequence
+            // 0b1110xxxx - start of 3-byte sequence
+            // 0b11110xxx - start of 4-byte sequence
+            size_t char0_bytes = (v0 & 0x80) == 0      ? 1
+                                 : (v0 & 0xE0) == 0xC0 ? 2
                                 : (v0 & 0xF0) == 0xE0 ? 3
                                                       : 4;
             if (char0_bytes == str_len) {
                 // If we have a 128-255 value, we can decode it into a single char:
                 if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx
-                    one_char = static_cast<CharT>(((v0 & 3) << 6) + (static_cast<unsigned char>(value[1]) & 0x3F));
+                    one_char = static_cast<CharT>(((v0 & 3) << 6)
+                                                  + (static_cast<unsigned char>(value[1]) & 0x3F));
                     return one_char;
                 }
                 // Otherwise we have a single character, but it's > U+00FF
@@ -1336,36 +598,42 @@ public:
         // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a
         // surrogate pair with total length 2 instantly indicates a range error (but not a "your
         // string was too long" error).
-        else if (StringCaster::UTF_N == 16 && str_len == 2) {
+        else if (PYBIND11_SILENCE_MSVC_C4127(StringCaster::UTF_N == 16) && str_len == 2) {
             one_char = static_cast<CharT>(value[0]);
-            if (one_char >= 0xD800 && one_char < 0xE000)
+            if (one_char >= 0xD800 && one_char < 0xE000) {
                 throw value_error("Character code point not in range(0x10000)");
+            }
         }

-        if (str_len != 1)
+        if (str_len != 1) {
             throw value_error("Expected a character, but multi-character string found");
+        }

         one_char = value[0];
         return one_char;
     }

-    static constexpr auto name = _(PYBIND11_STRING_NAME);
-    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+    static constexpr auto name = const_name(PYBIND11_STRING_NAME);
+    template <typename _T>
+    using cast_op_type = pybind11::detail::cast_op_type<_T>;
 };

 // Base implementation for std::tuple and std::pair
-template <template<typename...> class Tuple, typename... Ts> class tuple_caster {
+template