[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)

Author: Philippe Tillet
Date: 2021-11-29 19:11:26 -08:00
Committed by: GitHub
Parent: 1296eb877b
Commit: c86ad9c9ab
4 changed files with 149 additions and 122 deletions
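What the change enables, in one sketch: a non-kernel @triton.jit'd function can now declare a Python-style default argument, and call sites inside a kernel may omit it. This mirrors the test_default test added by this commit; the kernel and tensor names here are illustrative.

import torch
import triton
import triton.language as tl

@triton.jit
def _impl(value=10):
    # non-kernel helper with a default argument
    return value

@triton.jit
def _kernel(out0, out1, value):
    tl.store(out0, _impl())        # default used: stores 10
    tl.store(out1, _impl(value))   # explicit argument overrides the default

out0 = torch.zeros(1, dtype=torch.int32, device='cuda')
out1 = torch.zeros(1, dtype=torch.int32, device='cuda')
_kernel[(1,)](out0, out1, 5)
assert out0.item() == 10 and out1.item() == 5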

View File

@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       continue;
     }
     // argument is `constexpr`
-    py::object value = arg.attr("value");
-    if(value){
+    if(py::hasattr(arg, "value")){
+      py::object value = arg.attr("value");
       py::object name = arg_names[i];
       constants[name] = value;
       py::object repr = py::repr(value);
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       cache_key += std::string(start, len);
       continue;
     }
-    assert(false);
+    std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
+    std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
+                          + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
+    throw std::runtime_error(err_msg);
   }
   cache_key += std::to_string(num_warps);
   cache_key += std::to_string(num_stages);
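User-visible effect of the hunk above: an argument of an unsupported Python type is now reported with a descriptive exception instead of failing the old assert(false). A hedged sketch of how this surfaces on the Python side (the kernel is illustrative, and the exact propagation of the C++ exception is an assumption):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(dst, src):
    tl.store(dst, tl.load(src))

dst = torch.zeros(1, device='cuda')
try:
    _copy[(1,)](dst, "not-a-tensor")  # a str is not a supported argument type
except RuntimeError as e:
    # e.g. "Received type 'str' for argument 1. Only int, float, bool,
    # torch.Tensor, and triton.language.constexpr are supported."
    print(e)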
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
       CU_LAUNCH_PARAM_END
     };
     uint64_t _stream = PyLong_AsLong(stream.ptr());
-    drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
-                                  _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
-                                  nullptr, config);
+    if(grid_0*grid_1*grid_2 > 0)
+      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                    _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
+                                    nullptr, config);
     return bin;
   });
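The grid guard above turns a launch over an empty grid into a no-op instead of an invalid cuLaunchKernel call. A hedged sketch of what that permits from Python (kernel and sizes are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def _fill_ones(out, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(out + offs, 1.0)

out = torch.zeros(128, device='cuda')
n_blocks = 0                             # e.g. a data-dependent grid that happens to be empty
_fill_ones[(n_blocks,)](out, BLOCK=128)  # compiles, but nothing is launched; out stays zero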
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
                        const std::string &args, int64_t shared_mem){
     void* args_ptr = (void*)args.data();
     size_t args_size = args.size();
+    // release the gil in case the enqueue blocks
+    // cuda will block if too many ops are enqueued
+    Py_BEGIN_ALLOW_THREADS
     if(backend == HOST)
       host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
     if(backend == CUDA)
       cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
     if(backend == ROCM)
       hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
+    Py_END_ALLOW_THREADS
   });

View File

@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
 # test while
 # ---------------
 
+# ---------------
+# test default
+# ---------------
+#TODO: can't be local to test_default
+@triton.jit
+def _impl(value = 10):
+    return value
+
+
+def test_default():
+    value = 5
+    ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
+    ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
+
+    @triton.jit
+    def _kernel(ret0, ret1, value):
+        tl.store(ret0, _impl())
+        tl.store(ret1, _impl(value))
+
+    _kernel[(1,)](ret0, ret1, value)
+    assert ret0.item() == 10
+    assert ret1.item() == value
+
 # ---------------
 # test noop
 #----------------

View File

@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_FunctionDef(self, node, inline=False, arg_values=None):
         arg_names, kwarg_names = self.visit(node.args)
+        # initialize defaults
+        for i, default_value in enumerate(node.args.defaults):
+            arg_node = node.args.args[-i-1]
+            annotation = arg_node.annotation
+            name = arg_node.arg
+            st_target = ast.Name(id=name, ctx=ast.Store())
+            if annotation is None:
+                init_node = ast.Assign(targets=[st_target], value=default_value)
+            else:
+                init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
+            self.visit(init_node)
         # store keyword arguments in local scope
         self.lscope[kwarg_names] = self.kwargs
         # initialize function
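The loop above turns each default into an assignment (or annotated assignment) node that the generator visits before the function body, so a default behaves as if it were the first statement of the jit'd function. A standalone sketch of that AST manipulation, in plain Python and independent of Triton:

import ast

src = "def _impl(value=10):\n    return value\n"
fn = ast.parse(src).body[0]
default_value = fn.args.defaults[0]               # ast.Constant(value=10)
name = fn.args.args[-1].arg                       # 'value'
st_target = ast.Name(id=name, ctx=ast.Store())
init_node = ast.Assign(targets=[st_target], value=default_value)  # i.e. `value = 10`
ast.fix_missing_locations(init_node)
print(ast.dump(init_node))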
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
         iterator = self.visit(node.iter.func)
         if iterator != self.builtins['range']:
             raise RuntimeError('Only `range` iterator currently supported')
+        # static for loops: all iterator arguments are constexpr
+        iter_args = [self.visit(arg) for arg in node.iter.args]
+        is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
+        if is_static:
+            st_target = ast.Name(id=node.target.id, ctx=ast.Store())
+            iter_args = [arg.value for arg in iter_args]
+            range = iterator(*iter_args)
+            if len(range) <= 10:
+                for i in iterator(*iter_args):
+                    self.lscope[node.target.id] = triton.language.constexpr(i)
+                    self.visit_compound_statement(node.body)
+                    for stmt in node.orelse:
+                        ast.NodeVisitor.generic_visit(self, stmt)
+                return
         # create nodes
         st_target = ast.Name(id=node.target.id, ctx=ast.Store())
         ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
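This static-unrolling path is what lets a jit'd helper iterate over a constexpr trip count, which the rewritten philox_f later in this commit relies on (for _ in range(n_rounds) with a constexpr n_rounds). An illustrative helper that would take this path (names and body are illustrative, not from the commit):

import triton
import triton.language as tl

@triton.jit
def _static_sum(x_ptr, N: tl.constexpr):
    total = 0.0
    for i in range(N):                    # N is constexpr, so the body is visited once per iteration
        total = total + tl.load(x_ptr + i)
    return total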
@@ -483,6 +508,7 @@ class CompilationError(Exception):
         self.message += '\n' + ' ' * node.col_offset + '^'
         self.message += '\n Error: ' + str(err)
         super().__init__(self.message)
+        self.args = (src, node, err)
 
 
 class OutOfResources(Exception):
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
                        f'Required: {required}'\
                        f'Hardware limit: {limit}'
         super().__init__(self.message)
+        self.args = (required, limit, name)
 
 
 class Kernel:
@@ -805,7 +832,10 @@ class JITFunction:
         # information of wrapped function
         self.fn = fn
         self.module = fn.__module__
-        self.arg_names = inspect.getfullargspec(fn).args
+        signature = inspect.signature(fn)
+        self.arg_names = [v.name for v in signature.parameters.values()]
+        self.arg_defaults = [v.default for v in signature.parameters.values()]
         self.version = version
         self.src = textwrap.dedent(inspect.getsource(fn))
         self.do_not_specialize = [] if do_not_specialize is None else\
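Switching from inspect.getfullargspec to inspect.signature lets the wrapper additionally record a per-parameter default (inspect.Parameter.empty when none was given). A standalone sketch of what gets captured, in plain Python:

import inspect

def _impl(ret0, ret1, value=10):
    pass

sig = inspect.signature(_impl)
arg_names = [v.name for v in sig.parameters.values()]
arg_defaults = [v.default for v in sig.parameters.values()]
print(arg_names)                                             # ['ret0', 'ret1', 'value']
print([d is inspect.Parameter.empty for d in arg_defaults])  # [True, True, False]
print(arg_defaults[-1])                                      # 10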
@@ -829,7 +859,7 @@ class JITFunction:
         if not hasattr(self.fn, 'hash'):
             dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
             dependencies_finder.visit(self.parse())
-            self.fn.hash = dependencies_finder.ret
+            self.fn.hash = dependencies_finder.ret + version_key()
         return self.fn.hash
 
     # we do not parse `src` in the constructor because
@@ -848,6 +878,7 @@ class JITFunction:
         lscope = generator.lscope.copy()
         values = generator.module.get_values().copy()
         generator.gscope = sys.modules[self.fn.__module__].__dict__
+        generator.lscope = dict()
         ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
         generator.gscope = gscope
         generator.lscope = lscope

View File

@@ -7,98 +7,56 @@ from . import core as tl
 # 2. multiply_low_high is currently inefficient.
 # 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
 
+PHILOX_KEY_A: tl.constexpr = -1640531527  # 0x9E3779B9
+PHILOX_KEY_B: tl.constexpr = -1150833019  # 0xBB67AE85
+PHILOX_ROUND_A: tl.constexpr = -766435501  # 0xD2511F53
+PHILOX_ROUND_B: tl.constexpr = -845247145  # 0xCD9E8D57
+N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox
+
+# -------------------
+# randint
+# -------------------
-@triton.jit
-def PHILOX_KEY_A():
-    # 0x9E3779B9
-    return -1640531527
-
-
-@triton.jit
-def PHILOX_KEY_B():
-    # 0xBB67AE85
-    return -1150833019
-
-
-@triton.jit
-def PHILOX_ROUND_A():
-    # 0xD2511F53
-    return -766435501
-
-
-@triton.jit
-def PHILOX_ROUND_B():
-    # 0xCD9E8D57
-    return -845247145
 
 
 @triton.jit
 def hacky_to_uint64(x):
     return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
 
 
 @triton.jit
-def single_round(c0, c1, c2, c3, k0, k1):
-    A = PHILOX_ROUND_A()
-    B = PHILOX_ROUND_B()
-    _c0, _c2 = c0, c2
-    c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
-    c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
-    c1 = B * _c2
-    c3 = A * _c0
+def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
+    """
+    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
+    """
+    for _ in range(n_rounds):
+        # update random state
+        A = PHILOX_ROUND_A
+        B = PHILOX_ROUND_B
+        _c0, _c2 = c0, c2
+        c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
+        c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
+        c1 = B * _c2
+        c3 = A * _c0
+        # raise key
+        k0 = k0 + PHILOX_KEY_A
+        k1 = k1 + PHILOX_KEY_B
     return c0, c1, c2, c3
 
 
 @triton.jit
-def raise_key(k0, k1):
-    return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
-
-
-@triton.jit
-def philox_f(c0, c1, c2, c3, k0, k1):
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    k0, k1 = raise_key(k0, k1)
-    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
-    return c0, c1, c2, c3
-
-
-@triton.jit
-def uint32_to_uniform_float(x):
+def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
-    Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
-    This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
-    covers all the possible values it can take.
+    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
+    block of random :code:`int32`.
+
+    If you need multiple streams of random numbers,
+    using `randint4x` is likely to be faster than calling `randint` 4 times.
+
+    :param seed: The seed for generating random numbers.
+    :param offsets: The offsets to generate random numbers for.
     """
-    max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
-    x = tl.where(x < 0, -x - 1, x)
-    return x * max
+    ret, _, _, _ = randint4x(seed, offset, n_rounds)
+    return ret
 
 
 @triton.jit
-def pair_uniform_to_normal(u1, u2):
-    """Box-Muller transform"""
-    u1 = tl.maximum(1.0e-7, u1)
-    th = 6.283185307179586 * u2
-    r = tl.sqrt(-2.0 * tl.log(u1))
-    return r * tl.cos(th), r * tl.sin(th)
-
-
-@triton.jit
-def randint4x(seed, offset):
+def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
     Given a :code:`seed` scalar and an :code:`offset` block, returns four
     blocks of random :code:`int32`.
@@ -114,27 +72,26 @@ def randint4x(seed, offset):
     seed = hacky_to_uint64(seed) # uint will solve this
     seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
     seed_lo = (seed & 0xffffffff).to(tl.int32)
-    return philox_f(offset, z, z, z, seed_lo, seed_hi)
+    return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
+
+
+# -------------------
+# rand
+# -------------------
 
 
 @triton.jit
-def randint(seed, offset):
+def uint32_to_uniform_float(x):
     """
-    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
-    block of random :code:`int32`.
-
-    If you need multiple streams of random numbers,
-    using `randint4x` is likely to be faster than calling `randint` 4 times.
-
-    :param seed: The seed for generating random numbers.
-    :param offsets: The offsets to generate random numbers for.
+    Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
+    This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
+    covers all the possible values it can take.
     """
-    ret, _, _, _ = randint4x(seed, offset)
-    return ret
+    max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
+    x = tl.where(x < 0, -x - 1, x)
+    return x * max
 
 
 @triton.jit
-def rand(seed, offset):
+def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
     Given a :code:`seed` scalar and an :code:`offset` block,
     returns a block of random :code:`float32` in :math:`U(0, 1)`
@@ -142,28 +99,11 @@ def rand(seed, offset):
     :param seed: The seed for generating random numbers.
     :param offsets: The offsets to generate random numbers for.
     """
-    source = randint(seed, offset)
+    source = randint(seed, offset, n_rounds)
     return uint32_to_uniform_float(source)
 
 
 @triton.jit
-def randn(seed, offset):
-    """
-    Given a :code:`seed` scalar and an :code:`offset` block,
-    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
-
-    :param seed: The seed for generating random numbers.
-    :param offsets: The offsets to generate random numbers for.
-    """
-    i1, i2, _, _ = randint4x(seed, offset)
-    u1 = uint32_to_uniform_float(i1)
-    u2 = uint32_to_uniform_float(i2)
-    n1, _ = pair_uniform_to_normal(u1, u2)
-    return n1
-
-
-@triton.jit
-def rand4x(seed, offsets):
+def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
     Given a :code:`seed` scalar and an :code:`offsets` block,
     returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
@@ -171,16 +111,42 @@ def rand4x(seed, offsets):
     :param seed: The seed for generating random numbers.
     :param offsets: The offsets to generate random numbers for.
     """
-    i1, i2, i3, i4 = randint4x(seed, offsets)
+    i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
     u1 = uint32_to_uniform_float(i1)
     u2 = uint32_to_uniform_float(i2)
     u3 = uint32_to_uniform_float(i3)
     u4 = uint32_to_uniform_float(i4)
     return u1, u2, u3, u4
 
+
+# -------------------
+# randn
+# -------------------
+
 
 @triton.jit
-def randn4x(seed, offset):
+def pair_uniform_to_normal(u1, u2):
+    """Box-Muller transform"""
+    u1 = tl.maximum(1.0e-7, u1)
+    th = 6.283185307179586 * u2
+    r = tl.sqrt(-2.0 * tl.log(u1))
+    return r * tl.cos(th), r * tl.sin(th)
+
+
+@triton.jit
+def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
+    """
+    Given a :code:`seed` scalar and an :code:`offset` block,
+    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
+
+    :param seed: The seed for generating random numbers.
+    :param offsets: The offsets to generate random numbers for.
+    """
+    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
+    u1 = uint32_to_uniform_float(i1)
+    u2 = uint32_to_uniform_float(i2)
+    n1, _ = pair_uniform_to_normal(u1, u2)
+    return n1
+
+
+@triton.jit
+def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
     Given a :code:`seed` scalar and an :code:`offset` block,
     returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
@@ -188,7 +154,7 @@ def randn4x(seed, offset):
     :param seed: The seed for generating random numbers.
     :param offsets: The offsets to generate random numbers for.
     """
-    u1, u2, u3, u4 = rand4x(seed, offset)
+    u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
     n1, n2 = pair_uniform_to_normal(u1, u2)
     n3, n4 = pair_uniform_to_normal(u3, u4)
     return n1, n2, n3, n4
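Net effect for triton.language.random: every entry point (randint, randint4x, rand, rand4x, randn, randn4x) now takes an optional n_rounds defaulting to N_ROUNDS_DEFAULT (10), which is exactly the kind of non-kernel default argument this commit adds support for. A hedged usage sketch, assuming these helpers are re-exported as tl.rand as in the Triton tutorials (the kernel is illustrative; n_rounds is passed positionally):

import torch
import triton
import triton.language as tl

@triton.jit
def _rand_kernel(out0, out1, seed, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(out0 + offs, tl.rand(seed, offs))     # default: N_ROUNDS_DEFAULT (10) Philox rounds
    tl.store(out1 + offs, tl.rand(seed, offs, 7))  # explicit override: 7 rounds

out0 = torch.empty(128, device='cuda')
out1 = torch.empty(128, device='cuda')
_rand_kernel[(1,)](out0, out1, 42, BLOCK=128)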