diff --git a/python/src/triton.cc b/python/src/triton.cc index 2c165dd06..26c233287 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f continue; } // argument is `constexpr` - py::object value = arg.attr("value"); - if(value){ + if(py::hasattr(arg, "value")){ + py::object value = arg.attr("value"); py::object name = arg_names[i]; constants[name] = value; py::object repr = py::repr(value); @@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f cache_key += std::string(start, len); continue; } - assert(false); + std::string ty_str = arg.attr("__class__").attr("__name__").cast(); + std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "." + + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; + throw std::runtime_error(err_msg); } cache_key += std::to_string(num_warps); cache_key += std::to_string(num_stages); @@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) { CU_LAUNCH_PARAM_END }; uint64_t _stream = PyLong_AsLong(stream.ptr()); - drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, - _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, - nullptr, config); + if(grid_0*grid_1*grid_2 > 0) + drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, + _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, + nullptr, config); return bin; }); @@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) { const std::string &args, int64_t shared_mem){ void* args_ptr = (void*)args.data(); size_t args_size = args.size(); + // release the gil in case the enqueue blocks + // cuda will block if too many ops are enqueued + Py_BEGIN_ALLOW_THREADS if(backend == HOST) host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); if(backend == CUDA) cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); if(backend == ROCM) hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); + Py_END_ALLOW_THREADS }); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9354ec233..98c8c34fa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -634,6 +634,28 @@ def test_load_cache_modifier(cache): # test while # --------------- +# --------------- +# test default +# --------------- +#TODO: can't be local to test_default +@triton.jit +def _impl(value = 10): + return value + +def test_default(): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') + ret1 = torch.zeros(1, dtype=torch.int32, device='cuda') + + @triton.jit + def _kernel(ret0, ret1, value): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1,)](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + # --------------- # test noop #---------------- diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 84d77795c..b2fded136 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor): def visit_FunctionDef(self, node, inline=False, arg_values=None): arg_names, kwarg_names = self.visit(node.args) + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i-1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + self.visit(init_node) # store keyword arguments in local scope self.lscope[kwarg_names] = self.kwargs # initialize function @@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor): iterator = self.visit(node.iter.func) if iterator != self.builtins['range']: raise RuntimeError('Only `range` iterator currently supported') + # static for loops: all iterator arguments are constexpr + iter_args = [self.visit(arg) for arg in node.iter.args] + is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) + if is_static: + st_target = ast.Name(id=node.target.id, ctx=ast.Store()) + iter_args = [arg.value for arg in iter_args] + range = iterator(*iter_args) + if len(range) <= 10: + for i in iterator(*iter_args): + self.lscope[node.target.id] = triton.language.constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return # create nodes st_target = ast.Name(id=node.target.id, ctx=ast.Store()) ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) @@ -483,6 +508,7 @@ class CompilationError(Exception): self.message += '\n' + ' ' * node.col_offset + '^' self.message += '\n Error: ' + str(err) super().__init__(self.message) + self.args = (src, node, err) class OutOfResources(Exception): @@ -491,6 +517,7 @@ class OutOfResources(Exception): f'Required: {required}'\ f'Hardware limit: {limit}' super().__init__(self.message) + self.args = (required, limit, name) class Kernel: @@ -805,7 +832,10 @@ class JITFunction: # information of wrapped function self.fn = fn self.module = fn.__module__ - self.arg_names = inspect.getfullargspec(fn).args + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + self.arg_defaults = [v.default for v in signature.parameters.values()] + self.version = version self.src = textwrap.dedent(inspect.getsource(fn)) self.do_not_specialize = [] if do_not_specialize is None else\ @@ -829,7 +859,7 @@ class JITFunction: if not hasattr(self.fn, 'hash'): dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src) dependencies_finder.visit(self.parse()) - self.fn.hash = dependencies_finder.ret + self.fn.hash = dependencies_finder.ret + version_key() return self.fn.hash # we do not parse `src` in the constructor because @@ -848,6 +878,7 @@ class JITFunction: lscope = generator.lscope.copy() values = generator.module.get_values().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ + generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) generator.gscope = gscope generator.lscope = lscope diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 9bb29588a..a831af487 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -7,98 +7,56 @@ from . import core as tl # 2. multiply_low_high is currently inefficient. # 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float +PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9 +PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85 +PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53 +PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57 +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox -@triton.jit -def PHILOX_KEY_A(): - # 0x9E3779B9 - return -1640531527 - - -@triton.jit -def PHILOX_KEY_B(): - # 0xBB67AE85 - return -1150833019 - - -@triton.jit -def PHILOX_ROUND_A(): - # 0xD2511F53 - return -766435501 - - -@triton.jit -def PHILOX_ROUND_B(): - # 0xCD9E8D57 - return -845247145 +# ------------------- +# randint +# ------------------- @triton.jit def hacky_to_uint64(x): return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64) - @triton.jit -def single_round(c0, c1, c2, c3, k0, k1): - A = PHILOX_ROUND_A() - B = PHILOX_ROUND_B() - _c0, _c2 = c0, c2 - c0 = tl.umulhi(B, _c2) ^ c1 ^ k0 - c2 = tl.umulhi(A, _c0) ^ c3 ^ k1 - c1 = B * _c2 - c3 = A * _c0 +def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = tl.umulhi(B, _c2) ^ c1 ^ k0 + c2 = tl.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B return c0, c1, c2, c3 - @triton.jit -def raise_key(k0, k1): - return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B()) - -@triton.jit -def philox_f(c0, c1, c2, c3, k0, k1): - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - k0, k1 = raise_key(k0, k1) - c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1) - return c0, c1, c2, c3 - - - -@triton.jit -def uint32_to_uniform_float(x): +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ - Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). - This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly - covers all the possible values it can take. + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. """ - max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. - x = tl.where(x < 0, -x - 1, x) - return x * max + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret @triton.jit -def pair_uniform_to_normal(u1, u2): - """Box-Muller transform""" - u1 = tl.maximum(1.0e-7, u1) - th = 6.283185307179586 * u2 - r = tl.sqrt(-2.0 * tl.log(u1)) - return r * tl.cos(th), r * tl.sin(th) - - -@triton.jit -def randint4x(seed, offset): +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, returns four blocks of random :code:`int32`. @@ -114,27 +72,26 @@ def randint4x(seed, offset): seed = hacky_to_uint64(seed) # uint will solve this seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) seed_lo = (seed & 0xffffffff).to(tl.int32) - return philox_f(offset, z, z, z, seed_lo, seed_hi) + return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds) +# ------------------- +# rand +# ------------------- + @triton.jit -def randint(seed, offset): +def uint32_to_uniform_float(x): """ - Given a :code:`seed` scalar and an :code:`offset` block, returns a single - block of random :code:`int32`. - - If you need multiple streams of random numbers, - using `randint4x` is likely to be faster than calling `randint` 4 times. - - :param seed: The seed for generating random numbers. - :param offsets: The offsets to generate random numbers for. + Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). + This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly + covers all the possible values it can take. """ - ret, _, _, _ = randint4x(seed, offset) - return ret - + max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647. + x = tl.where(x < 0, -x - 1, x) + return x * max @triton.jit -def rand(seed, offset): +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, returns a block of random :code:`float32` in :math:`U(0, 1)` @@ -142,28 +99,11 @@ def rand(seed, offset): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - source = randint(seed, offset) + source = randint(seed, offset, n_rounds) return uint32_to_uniform_float(source) - @triton.jit -def randn(seed, offset): - """ - Given a :code:`seed` scalar and an :code:`offset` block, - returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` - - :param seed: The seed for generating random numbers. - :param offsets: The offsets to generate random numbers for. - """ - i1, i2, _, _ = randint4x(seed, offset) - u1 = uint32_to_uniform_float(i1) - u2 = uint32_to_uniform_float(i2) - n1, _ = pair_uniform_to_normal(u1, u2) - return n1 - - -@triton.jit -def rand4x(seed, offsets): +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offsets` block, returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)` @@ -171,16 +111,42 @@ def rand4x(seed, offsets): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - i1, i2, i3, i4 = randint4x(seed, offsets) + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) u1 = uint32_to_uniform_float(i1) u2 = uint32_to_uniform_float(i2) u3 = uint32_to_uniform_float(i3) u4 = uint32_to_uniform_float(i4) return u1, u2, u3, u4 +# ------------------- +# randn +# ------------------- @triton.jit -def randn4x(seed, offset): +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = tl.sqrt(-2.0 * tl.log(u1)) + return r * tl.cos(th), r * tl.sin(th) + +@triton.jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint32_to_uniform_float(i1) + u2 = uint32_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + +@triton.jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` @@ -188,7 +154,7 @@ def randn4x(seed, offset): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - u1, u2, u3, u4 = rand4x(seed, offset) + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) n1, n2 = pair_uniform_to_normal(u1, u2) n3, n4 = pair_uniform_to_normal(u3, u4) return n1, n2, n3, n4