From 4c97d1ecd778c287e57617eb01c4edd03cf855f6 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 20 Jan 2022 10:55:59 -0800 Subject: [PATCH] [FRONTEND] Bunch of fixes here and there (#436) --- lib/ir/dispatch.cc | 7 +++- python/src/triton.cc | 6 ++- python/test/unit/language/test_core.py | 2 +- python/test/unit/runtime/test_cache.py | 6 +-- python/triton/code_gen.py | 52 ++++++++++++++------------ python/triton/language/core.py | 19 +++++++++- python/triton/language/random.py | 18 ++++++--- 7 files changed, 71 insertions(+), 39 deletions(-) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 73d4ddc94..69c76b5e5 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -745,8 +745,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir: x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); } - if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty()) - throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty()); + ir::type* x_ty = x->get_type()->get_scalar_ty(); + ir::type* y_ty = y->get_type()->get_scalar_ty(); + ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO); + x = dispatch::cast(x, ty, builder); + y = dispatch::cast(y, ty, builder); return builder->create_select(condition, x, y); } diff --git a/python/src/triton.cc b/python/src/triton.cc index cb6923afe..8ceb14200 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -329,7 +329,6 @@ void init_triton_runtime(py::module &&m) { // cuda will block if too many ops are enqueued Py_BEGIN_ALLOW_THREADS - drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); @@ -466,6 +465,9 @@ std::tuple hip_load_binary(const std::string& name, asm_map_ std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, uint64_t device, int num_warps, int num_stages, asm_map_t &asm_map){ + + int n_shared_bytes; + Py_BEGIN_ALLOW_THREADS llvm::LLVMContext ctx; // device properties CUdevice dev = (CUdevice)device; @@ -476,7 +478,6 @@ std::tuple cu_compile_ttir(const std::string& name, drv::dispatch::cuDriverGetVersion(&version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); - int n_shared_bytes; auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); std::string tmp; llvm::raw_string_ostream llir(tmp); @@ -492,6 +493,7 @@ std::tuple cu_compile_ttir(const std::string& name, py::bytes bytes(cubin); asm_map["cubin"] = bytes; } + Py_END_ALLOW_THREADS return std::make_tuple(name, asm_map, n_shared_bytes); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 64f60e260..4d6e32aa6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -79,7 +79,7 @@ def to_numpy(x): def patch_kernel(template, to_replace): - kernel = copy.deepcopy(template) + kernel = triton.JITFunction(template.fn) for key, value in to_replace.items(): kernel.src = kernel.src.replace(key, value) return kernel diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 51c69b5b6..48797b51a 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -39,9 +39,9 @@ def kernel_nospec(X, i, BLOCK: tl.constexpr): def apply_src_change(target, old, new): - delattr(kernel.fn, 'hash') - delattr(function_1.fn, 'hash') - delattr(function_2.fn, 'hash') + kernel.hash = None + function_1.hash = None + function_2.hash = None function_1.src = function_1.src.replace(old, new) target.src = target.src.replace(old, new) ret = target.cache_key diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 960df4efc..347635e32 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -651,7 +651,7 @@ class Kernel: prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR # export symbols visible from self.fn into code-generator object - gscope = self.fn.fn.__globals__ + gscope = self.fn.__globals__ generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) try: generator.visit(self.fn.parse()) @@ -723,7 +723,7 @@ class Kernel: pickle.dump({"binary": binary, "key": key}, f) os.rename(bin_cache_path + ".tmp", bin_cache_path) if JITFunction.cache_hook is not None: - name = self.fn.fn.__name__ + name = self.fn.__name__ info = key.split('-')[-3:] num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] # make signature human-readable @@ -885,8 +885,6 @@ def version_key(): ptxas_version = '' return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) -# 3 - class DependenciesFinder(ast.NodeVisitor): @@ -910,17 +908,14 @@ class DependenciesFinder(ast.NodeVisitor): func = self.visit(node.func) if func is None: return - if isinstance(func, triton.JITFunction): - func = func.fn - module = inspect.getmodule(func) - if module and module.__name__.startswith('triton.'): - return if inspect.isbuiltin(func): return - if not hasattr(func, 'hash'): - src = textwrap.dedent(inspect.getsource(func)) - tree = ast.parse(src) - finder = DependenciesFinder(func.__globals__, src) + if func.__module__ and func.__module__.startswith('triton.'): + return + assert isinstance(func, triton.JITFunction) + if func.hash is None: + tree = ast.parse(func.src) + finder = DependenciesFinder(func.__globals__, func.src) finder.visit(tree) func.hash = finder.ret self.ret = (self.ret + func.hash).encode("utf-8") @@ -941,10 +936,12 @@ class JITFunction: self.version = version self.src = textwrap.dedent(inspect.getsource(fn)) - self.do_not_specialize = [] if do_not_specialize is None else\ - [self.arg_names.index(arg) for arg in do_not_specialize] + self.src = self.src[self.src.find("def"):] + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] # cache for callable driver objects (e.g. CUkernel) self.bin_cache = dict() + self.hash = None # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel_decorators = [] @@ -954,15 +951,19 @@ class JITFunction: self.__annotations__ = fn.__annotations__ # forward docs self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ @property @functools.lru_cache() def cache_key(self): - if not hasattr(self.fn, 'hash'): - dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src) + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) dependencies_finder.visit(self.parse()) - self.fn.hash = dependencies_finder.ret + version_key() - return self.fn.hash + self.hash = dependencies_finder.ret + version_key() + return self.hash # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. @@ -974,14 +975,20 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - def __call__(self, *args, generator: CodeGenerator): + def __call__(self, *args, generator: CodeGenerator, **kwargs): try: + from inspect import getcallargs + arg_values = getcallargs(self.fn, *args, **kwargs) + arg_values = [arg_values[name] for name in self.arg_names] + arg_values = [arg if isinstance(arg, triton.language.block) + else triton.language.constexpr(arg) for arg in arg_values] + gscope = generator.gscope.copy() lscope = generator.lscope.copy() values = generator.module.get_values().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() - ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope generator.module.set_values(values) @@ -1001,8 +1008,7 @@ class JITFunction: self.kernel = None super(JITFunction, self).__setattr__(name, value) if name == 'src': - if hasattr(self.fn, 'hash'): - delattr(self.fn, 'hash') + self.hash = None JITFunction.cache_key.fget.cache_clear() def _init_kernel(self): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 6f1bf3ba7..df3b1f4cf 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -168,6 +168,8 @@ class block: self.numel = constexpr(self.numel) # Data-type wrapper self.dtype = block._init_dtype(self.handle.type.scalar) + # Shape is a constexpr + self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: # ex. "float32[3,4]" @@ -297,7 +299,7 @@ class block: if sl is None: dst_shape.append(1) elif sl == slice(None, None, None): - dst_shape.append(src_shape[curr]) + dst_shape.append(src_shape[curr].value) curr += 1 ret = frontend.reshape(self, dst_shape, _builder) return ret @@ -320,8 +322,15 @@ class constexpr: """ def __init__(self, value): - self.value = value + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + # def __add__(self, other): return self.value + other.value @@ -516,6 +525,7 @@ def reshape(input, shape, _builder=None): :type shape: Tuple[int] """ + shape = [x.value for x in shape] return frontend.reshape(input, shape, _builder) @@ -908,3 +918,8 @@ def swizzle2d(i, j, size_i, size_j, size_g): new_i = off_i + (ij % size_g) new_j = (ij % size_gj) // size_g return new_i, new_j + + +@triton.jit +def zeros_like(input): + return zeros(input.shape, input.dtype) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 69d7f4c4d..d5691bb72 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -13,7 +13,7 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox @triton.jit -def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). """ @@ -32,6 +32,14 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): return c0, c1, c2, c3 +@triton.jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = seed.to(tl.uint64) + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + @triton.jit def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ @@ -60,11 +68,9 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting - seed = seed.to(tl.uint64) - seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) - seed_lo = (seed & 0xffffffff).to(tl.uint32) - return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) # -------------------