[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)
This commit is contained in:
@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
continue;
|
||||
}
|
||||
// argument is `constexpr`
|
||||
py::object value = arg.attr("value");
|
||||
if(value){
|
||||
if(py::hasattr(arg, "value")){
|
||||
py::object value = arg.attr("value");
|
||||
py::object name = arg_names[i];
|
||||
constants[name] = value;
|
||||
py::object repr = py::repr(value);
|
||||
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
cache_key += std::string(start, len);
|
||||
continue;
|
||||
}
|
||||
assert(false);
|
||||
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
|
||||
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
|
||||
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||
throw std::runtime_error(err_msg);
|
||||
}
|
||||
cache_key += std::to_string(num_warps);
|
||||
cache_key += std::to_string(num_stages);
|
||||
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
uint64_t _stream = PyLong_AsLong(stream.ptr());
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
nullptr, config);
|
||||
if(grid_0*grid_1*grid_2 > 0)
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
nullptr, config);
|
||||
return bin;
|
||||
});
|
||||
|
||||
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
|
||||
const std::string &args, int64_t shared_mem){
|
||||
void* args_ptr = (void*)args.data();
|
||||
size_t args_size = args.size();
|
||||
// release the gil in case the enqueue blocks
|
||||
// cuda will block if too many ops are enqueued
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
if(backend == HOST)
|
||||
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
if(backend == CUDA)
|
||||
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
if(backend == ROCM)
|
||||
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
Py_END_ALLOW_THREADS
|
||||
});
|
||||
|
||||
|
||||
|
@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
|
||||
# test while
|
||||
# ---------------
|
||||
|
||||
# ---------------
|
||||
# test default
|
||||
# ---------------
|
||||
#TODO: can't be local to test_default
|
||||
@triton.jit
|
||||
def _impl(value = 10):
|
||||
return value
|
||||
|
||||
def test_default():
|
||||
value = 5
|
||||
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(ret0, ret1, value):
|
||||
tl.store(ret0, _impl())
|
||||
tl.store(ret1, _impl(value))
|
||||
|
||||
_kernel[(1,)](ret0, ret1, value)
|
||||
assert ret0.item() == 10
|
||||
assert ret1.item() == value
|
||||
|
||||
# ---------------
|
||||
# test noop
|
||||
#----------------
|
||||
|
@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
arg_node = node.args.args[-i-1]
|
||||
annotation = arg_node.annotation
|
||||
name = arg_node.arg
|
||||
st_target = ast.Name(id=name, ctx=ast.Store())
|
||||
if annotation is None:
|
||||
init_node = ast.Assign(targets=[st_target], value=default_value)
|
||||
else:
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# store keyword arguments in local scope
|
||||
self.lscope[kwarg_names] = self.kwargs
|
||||
# initialize function
|
||||
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iterator = self.visit(node.iter.func)
|
||||
if iterator != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||
if is_static:
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
iter_args = [arg.value for arg in iter_args]
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
for i in iterator(*iter_args):
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
# create nodes
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||
@@ -483,6 +508,7 @@ class CompilationError(Exception):
|
||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||
self.message += '\n Error: ' + str(err)
|
||||
super().__init__(self.message)
|
||||
self.args = (src, node, err)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
|
||||
f'Required: {required}'\
|
||||
f'Hardware limit: {limit}'
|
||||
super().__init__(self.message)
|
||||
self.args = (required, limit, name)
|
||||
|
||||
|
||||
class Kernel:
|
||||
@@ -805,7 +832,10 @@ class JITFunction:
|
||||
# information of wrapped function
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.arg_names = inspect.getfullargspec(fn).args
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
self.version = version
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.do_not_specialize = [] if do_not_specialize is None else\
|
||||
@@ -829,7 +859,7 @@ class JITFunction:
|
||||
if not hasattr(self.fn, 'hash'):
|
||||
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.fn.hash = dependencies_finder.ret
|
||||
self.fn.hash = dependencies_finder.ret + version_key()
|
||||
return self.fn.hash
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
@@ -848,6 +878,7 @@ class JITFunction:
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
|
@@ -7,98 +7,56 @@ from . import core as tl
|
||||
# 2. multiply_low_high is currently inefficient.
|
||||
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
|
||||
|
||||
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
|
||||
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
|
||||
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
|
||||
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
|
||||
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
||||
|
||||
@triton.jit
|
||||
def PHILOX_KEY_A():
|
||||
# 0x9E3779B9
|
||||
return -1640531527
|
||||
|
||||
|
||||
@triton.jit
|
||||
def PHILOX_KEY_B():
|
||||
# 0xBB67AE85
|
||||
return -1150833019
|
||||
|
||||
|
||||
@triton.jit
|
||||
def PHILOX_ROUND_A():
|
||||
# 0xD2511F53
|
||||
return -766435501
|
||||
|
||||
|
||||
@triton.jit
|
||||
def PHILOX_ROUND_B():
|
||||
# 0xCD9E8D57
|
||||
return -845247145
|
||||
# -------------------
|
||||
# randint
|
||||
# -------------------
|
||||
|
||||
@triton.jit
|
||||
def hacky_to_uint64(x):
|
||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def single_round(c0, c1, c2, c3, k0, k1):
|
||||
A = PHILOX_ROUND_A()
|
||||
B = PHILOX_ROUND_B()
|
||||
_c0, _c2 = c0, c2
|
||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||
c1 = B * _c2
|
||||
c3 = A * _c0
|
||||
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
||||
"""
|
||||
for _ in range(n_rounds):
|
||||
# update random state
|
||||
A = PHILOX_ROUND_A
|
||||
B = PHILOX_ROUND_B
|
||||
_c0, _c2 = c0, c2
|
||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||
c1 = B * _c2
|
||||
c3 = A * _c0
|
||||
# raise key
|
||||
k0 = k0 + PHILOX_KEY_A
|
||||
k1 = k1 + PHILOX_KEY_B
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
|
||||
def raise_key(k0, k1):
|
||||
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
|
||||
|
||||
@triton.jit
|
||||
def philox_f(c0, c1, c2, c3, k0, k1):
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
k0, k1 = raise_key(k0, k1)
|
||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
|
||||
@triton.jit
|
||||
def uint32_to_uniform_float(x):
|
||||
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
|
||||
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
|
||||
covers all the possible values it can take.
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
|
||||
block of random :code:`int32`.
|
||||
|
||||
If you need multiple streams of random numbers,
|
||||
using `randint4x` is likely to be faster than calling `randint` 4 times.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
||||
x = tl.where(x < 0, -x - 1, x)
|
||||
return x * max
|
||||
ret, _, _, _ = randint4x(seed, offset, n_rounds)
|
||||
return ret
|
||||
|
||||
@triton.jit
|
||||
def pair_uniform_to_normal(u1, u2):
|
||||
"""Box-Muller transform"""
|
||||
u1 = tl.maximum(1.0e-7, u1)
|
||||
th = 6.283185307179586 * u2
|
||||
r = tl.sqrt(-2.0 * tl.log(u1))
|
||||
return r * tl.cos(th), r * tl.sin(th)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def randint4x(seed, offset):
|
||||
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns four
|
||||
blocks of random :code:`int32`.
|
||||
@@ -114,27 +72,26 @@ def randint4x(seed, offset):
|
||||
seed = hacky_to_uint64(seed) # uint will solve this
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||
return philox_f(offset, z, z, z, seed_lo, seed_hi)
|
||||
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
|
||||
|
||||
|
||||
# -------------------
|
||||
# rand
|
||||
# -------------------
|
||||
|
||||
@triton.jit
|
||||
def randint(seed, offset):
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
|
||||
block of random :code:`int32`.
|
||||
|
||||
If you need multiple streams of random numbers,
|
||||
using `randint4x` is likely to be faster than calling `randint` 4 times.
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
|
||||
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
|
||||
covers all the possible values it can take.
|
||||
"""
|
||||
ret, _, _, _ = randint4x(seed, offset)
|
||||
return ret
|
||||
|
||||
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
||||
x = tl.where(x < 0, -x - 1, x)
|
||||
return x * max
|
||||
|
||||
@triton.jit
|
||||
def rand(seed, offset):
|
||||
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a block of random :code:`float32` in :math:`U(0, 1)`
|
||||
@@ -142,28 +99,11 @@ def rand(seed, offset):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
source = randint(seed, offset)
|
||||
source = randint(seed, offset, n_rounds)
|
||||
return uint32_to_uniform_float(source)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def randn(seed, offset):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
i1, i2, _, _ = randint4x(seed, offset)
|
||||
u1 = uint32_to_uniform_float(i1)
|
||||
u2 = uint32_to_uniform_float(i2)
|
||||
n1, _ = pair_uniform_to_normal(u1, u2)
|
||||
return n1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def rand4x(seed, offsets):
|
||||
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offsets` block,
|
||||
returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
|
||||
@@ -171,16 +111,42 @@ def rand4x(seed, offsets):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
i1, i2, i3, i4 = randint4x(seed, offsets)
|
||||
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
|
||||
u1 = uint32_to_uniform_float(i1)
|
||||
u2 = uint32_to_uniform_float(i2)
|
||||
u3 = uint32_to_uniform_float(i3)
|
||||
u4 = uint32_to_uniform_float(i4)
|
||||
return u1, u2, u3, u4
|
||||
|
||||
# -------------------
|
||||
# randn
|
||||
# -------------------
|
||||
|
||||
@triton.jit
|
||||
def randn4x(seed, offset):
|
||||
def pair_uniform_to_normal(u1, u2):
|
||||
"""Box-Muller transform"""
|
||||
u1 = tl.maximum(1.0e-7, u1)
|
||||
th = 6.283185307179586 * u2
|
||||
r = tl.sqrt(-2.0 * tl.log(u1))
|
||||
return r * tl.cos(th), r * tl.sin(th)
|
||||
|
||||
@triton.jit
|
||||
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
i1, i2, _, _ = randint4x(seed, offset, n_rounds)
|
||||
u1 = uint32_to_uniform_float(i1)
|
||||
u2 = uint32_to_uniform_float(i2)
|
||||
n1, _ = pair_uniform_to_normal(u1, u2)
|
||||
return n1
|
||||
|
||||
@triton.jit
|
||||
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||
@@ -188,7 +154,7 @@ def randn4x(seed, offset):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
u1, u2, u3, u4 = rand4x(seed, offset)
|
||||
u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
|
||||
n1, n2 = pair_uniform_to_normal(u1, u2)
|
||||
n3, n4 = pair_uniform_to_normal(u3, u4)
|
||||
return n1, n2, n3, n4
|
||||
|
Reference in New Issue
Block a user