[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)

This commit is contained in:
Philippe Tillet
2021-11-29 19:11:26 -08:00
committed by GitHub
parent 1296eb877b
commit c86ad9c9ab
4 changed files with 149 additions and 122 deletions

View File

@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
continue;
}
// argument is `constexpr`
py::object value = arg.attr("value");
if(value){
if(py::hasattr(arg, "value")){
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
py::object repr = py::repr(value);
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
cache_key += std::string(start, len);
continue;
}
assert(false);
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
cache_key += std::to_string(num_warps);
cache_key += std::to_string(num_stages);
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
CU_LAUNCH_PARAM_END
};
uint64_t _stream = PyLong_AsLong(stream.ptr());
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
if(grid_0*grid_1*grid_2 > 0)
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
return bin;
});
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
const std::string &args, int64_t shared_mem){
void* args_ptr = (void*)args.data();
size_t args_size = args.size();
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
Py_BEGIN_ALLOW_THREADS
if(backend == HOST)
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
if(backend == CUDA)
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
if(backend == ROCM)
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
Py_END_ALLOW_THREADS
});

View File

@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
# test while
# ---------------
# ---------------
# test default
# ---------------
#TODO: can't be local to test_default
@triton.jit
def _impl(value = 10):
return value
def test_default():
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
@triton.jit
def _kernel(ret0, ret1, value):
tl.store(ret0, _impl())
tl.store(ret1, _impl(value))
_kernel[(1,)](ret0, ret1, value)
assert ret0.item() == 10
assert ret1.item() == value
# ---------------
# test noop
#----------------

View File

@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
def visit_FunctionDef(self, node, inline=False, arg_values=None):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i-1]
annotation = arg_node.annotation
name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store())
if annotation is None:
init_node = ast.Assign(targets=[st_target], value=default_value)
else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
iterator = self.visit(node.iter.func)
if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# static for loops: all iterator arguments are constexpr
iter_args = [self.visit(arg) for arg in node.iter.args]
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
for i in iterator(*iter_args):
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
# create nodes
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
@@ -483,6 +508,7 @@ class CompilationError(Exception):
self.message += '\n' + ' ' * node.col_offset + '^'
self.message += '\n Error: ' + str(err)
super().__init__(self.message)
self.args = (src, node, err)
class OutOfResources(Exception):
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
f'Required: {required}'\
f'Hardware limit: {limit}'
super().__init__(self.message)
self.args = (required, limit, name)
class Kernel:
@@ -805,7 +832,10 @@ class JITFunction:
# information of wrapped function
self.fn = fn
self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.src = textwrap.dedent(inspect.getsource(fn))
self.do_not_specialize = [] if do_not_specialize is None else\
@@ -829,7 +859,7 @@ class JITFunction:
if not hasattr(self.fn, 'hash'):
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.fn.hash = dependencies_finder.ret
self.fn.hash = dependencies_finder.ret + version_key()
return self.fn.hash
# we do not parse `src` in the constructor because
@@ -848,6 +878,7 @@ class JITFunction:
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
generator.gscope = gscope
generator.lscope = lscope

View File

@@ -7,98 +7,56 @@ from . import core as tl
# 2. multiply_low_high is currently inefficient.
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
@triton.jit
def PHILOX_KEY_A():
# 0x9E3779B9
return -1640531527
@triton.jit
def PHILOX_KEY_B():
# 0xBB67AE85
return -1150833019
@triton.jit
def PHILOX_ROUND_A():
# 0xD2511F53
return -766435501
@triton.jit
def PHILOX_ROUND_B():
# 0xCD9E8D57
return -845247145
# -------------------
# randint
# -------------------
@triton.jit
def hacky_to_uint64(x):
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
@triton.jit
def single_round(c0, c1, c2, c3, k0, k1):
A = PHILOX_ROUND_A()
B = PHILOX_ROUND_B()
_c0, _c2 = c0, c2
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
c1 = B * _c2
c3 = A * _c0
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
"""
for _ in range(n_rounds):
# update random state
A = PHILOX_ROUND_A
B = PHILOX_ROUND_B
_c0, _c2 = c0, c2
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
c1 = B * _c2
c3 = A * _c0
# raise key
k0 = k0 + PHILOX_KEY_A
k1 = k1 + PHILOX_KEY_B
return c0, c1, c2, c3
@triton.jit
def raise_key(k0, k1):
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1):
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
k0, k1 = raise_key(k0, k1)
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
return c0, c1, c2, c3
@triton.jit
def uint32_to_uniform_float(x):
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take.
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`.
If you need multiple streams of random numbers,
using `randint4x` is likely to be faster than calling `randint` 4 times.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
x = tl.where(x < 0, -x - 1, x)
return x * max
ret, _, _, _ = randint4x(seed, offset, n_rounds)
return ret
@triton.jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = tl.maximum(1.0e-7, u1)
th = 6.283185307179586 * u2
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)
@triton.jit
def randint4x(seed, offset):
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns four
blocks of random :code:`int32`.
@@ -114,27 +72,26 @@ def randint4x(seed, offset):
seed = hacky_to_uint64(seed) # uint will solve this
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
seed_lo = (seed & 0xffffffff).to(tl.int32)
return philox_f(offset, z, z, z, seed_lo, seed_hi)
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
# -------------------
# rand
# -------------------
@triton.jit
def randint(seed, offset):
def uint32_to_uniform_float(x):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
block of random :code:`int32`.
If you need multiple streams of random numbers,
using `randint4x` is likely to be faster than calling `randint` 4 times.
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take.
"""
ret, _, _, _ = randint4x(seed, offset)
return ret
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
x = tl.where(x < 0, -x - 1, x)
return x * max
@triton.jit
def rand(seed, offset):
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`U(0, 1)`
@@ -142,28 +99,11 @@ def rand(seed, offset):
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
source = randint(seed, offset)
source = randint(seed, offset, n_rounds)
return uint32_to_uniform_float(source)
@triton.jit
def randn(seed, offset):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
i1, i2, _, _ = randint4x(seed, offset)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
n1, _ = pair_uniform_to_normal(u1, u2)
return n1
@triton.jit
def rand4x(seed, offsets):
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offsets` block,
returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
@@ -171,16 +111,42 @@ def rand4x(seed, offsets):
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
i1, i2, i3, i4 = randint4x(seed, offsets)
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
u3 = uint32_to_uniform_float(i3)
u4 = uint32_to_uniform_float(i4)
return u1, u2, u3, u4
# -------------------
# randn
# -------------------
@triton.jit
def randn4x(seed, offset):
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = tl.maximum(1.0e-7, u1)
th = 6.283185307179586 * u2
r = tl.sqrt(-2.0 * tl.log(u1))
return r * tl.cos(th), r * tl.sin(th)
@triton.jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
i1, i2, _, _ = randint4x(seed, offset, n_rounds)
u1 = uint32_to_uniform_float(i1)
u2 = uint32_to_uniform_float(i2)
n1, _ = pair_uniform_to_normal(u1, u2)
return n1
@triton.jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
@@ -188,7 +154,7 @@ def randn4x(seed, offset):
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
u1, u2, u3, u4 = rand4x(seed, offset)
u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
n1, n2 = pair_uniform_to_normal(u1, u2)
n3, n4 = pair_uniform_to_normal(u3, u4)
return n1, n2, n3, n4