[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)
This commit is contained in:
@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// argument is `constexpr`
|
// argument is `constexpr`
|
||||||
py::object value = arg.attr("value");
|
if(py::hasattr(arg, "value")){
|
||||||
if(value){
|
py::object value = arg.attr("value");
|
||||||
py::object name = arg_names[i];
|
py::object name = arg_names[i];
|
||||||
constants[name] = value;
|
constants[name] = value;
|
||||||
py::object repr = py::repr(value);
|
py::object repr = py::repr(value);
|
||||||
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
cache_key += std::string(start, len);
|
cache_key += std::string(start, len);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
assert(false);
|
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
|
||||||
|
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
|
||||||
|
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||||
|
throw std::runtime_error(err_msg);
|
||||||
}
|
}
|
||||||
cache_key += std::to_string(num_warps);
|
cache_key += std::to_string(num_warps);
|
||||||
cache_key += std::to_string(num_stages);
|
cache_key += std::to_string(num_stages);
|
||||||
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
CU_LAUNCH_PARAM_END
|
CU_LAUNCH_PARAM_END
|
||||||
};
|
};
|
||||||
uint64_t _stream = PyLong_AsLong(stream.ptr());
|
uint64_t _stream = PyLong_AsLong(stream.ptr());
|
||||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
if(grid_0*grid_1*grid_2 > 0)
|
||||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||||
nullptr, config);
|
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||||
|
nullptr, config);
|
||||||
return bin;
|
return bin;
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
const std::string &args, int64_t shared_mem){
|
const std::string &args, int64_t shared_mem){
|
||||||
void* args_ptr = (void*)args.data();
|
void* args_ptr = (void*)args.data();
|
||||||
size_t args_size = args.size();
|
size_t args_size = args.size();
|
||||||
|
// release the gil in case the enqueue blocks
|
||||||
|
// cuda will block if too many ops are enqueued
|
||||||
|
Py_BEGIN_ALLOW_THREADS
|
||||||
if(backend == HOST)
|
if(backend == HOST)
|
||||||
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||||
if(backend == CUDA)
|
if(backend == CUDA)
|
||||||
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||||
if(backend == ROCM)
|
if(backend == ROCM)
|
||||||
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
|
||||||
|
Py_END_ALLOW_THREADS
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
|
|||||||
# test while
|
# test while
|
||||||
# ---------------
|
# ---------------
|
||||||
|
|
||||||
|
# ---------------
|
||||||
|
# test default
|
||||||
|
# ---------------
|
||||||
|
#TODO: can't be local to test_default
|
||||||
|
@triton.jit
|
||||||
|
def _impl(value = 10):
|
||||||
|
return value
|
||||||
|
|
||||||
|
def test_default():
|
||||||
|
value = 5
|
||||||
|
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||||
|
ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _kernel(ret0, ret1, value):
|
||||||
|
tl.store(ret0, _impl())
|
||||||
|
tl.store(ret1, _impl(value))
|
||||||
|
|
||||||
|
_kernel[(1,)](ret0, ret1, value)
|
||||||
|
assert ret0.item() == 10
|
||||||
|
assert ret1.item() == value
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# test noop
|
# test noop
|
||||||
#----------------
|
#----------------
|
||||||
|
@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
||||||
arg_names, kwarg_names = self.visit(node.args)
|
arg_names, kwarg_names = self.visit(node.args)
|
||||||
|
# initialize defaults
|
||||||
|
for i, default_value in enumerate(node.args.defaults):
|
||||||
|
arg_node = node.args.args[-i-1]
|
||||||
|
annotation = arg_node.annotation
|
||||||
|
name = arg_node.arg
|
||||||
|
st_target = ast.Name(id=name, ctx=ast.Store())
|
||||||
|
if annotation is None:
|
||||||
|
init_node = ast.Assign(targets=[st_target], value=default_value)
|
||||||
|
else:
|
||||||
|
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||||
|
self.visit(init_node)
|
||||||
# store keyword arguments in local scope
|
# store keyword arguments in local scope
|
||||||
self.lscope[kwarg_names] = self.kwargs
|
self.lscope[kwarg_names] = self.kwargs
|
||||||
# initialize function
|
# initialize function
|
||||||
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
iterator = self.visit(node.iter.func)
|
iterator = self.visit(node.iter.func)
|
||||||
if iterator != self.builtins['range']:
|
if iterator != self.builtins['range']:
|
||||||
raise RuntimeError('Only `range` iterator currently supported')
|
raise RuntimeError('Only `range` iterator currently supported')
|
||||||
|
# static for loops: all iterator arguments are constexpr
|
||||||
|
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||||
|
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||||
|
if is_static:
|
||||||
|
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||||
|
iter_args = [arg.value for arg in iter_args]
|
||||||
|
range = iterator(*iter_args)
|
||||||
|
if len(range) <= 10:
|
||||||
|
for i in iterator(*iter_args):
|
||||||
|
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||||
|
self.visit_compound_statement(node.body)
|
||||||
|
for stmt in node.orelse:
|
||||||
|
ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
|
return
|
||||||
# create nodes
|
# create nodes
|
||||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||||
@@ -483,6 +508,7 @@ class CompilationError(Exception):
|
|||||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||||
self.message += '\n Error: ' + str(err)
|
self.message += '\n Error: ' + str(err)
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
self.args = (src, node, err)
|
||||||
|
|
||||||
|
|
||||||
class OutOfResources(Exception):
|
class OutOfResources(Exception):
|
||||||
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
|
|||||||
f'Required: {required}'\
|
f'Required: {required}'\
|
||||||
f'Hardware limit: {limit}'
|
f'Hardware limit: {limit}'
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
self.args = (required, limit, name)
|
||||||
|
|
||||||
|
|
||||||
class Kernel:
|
class Kernel:
|
||||||
@@ -805,7 +832,10 @@ class JITFunction:
|
|||||||
# information of wrapped function
|
# information of wrapped function
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.module = fn.__module__
|
self.module = fn.__module__
|
||||||
self.arg_names = inspect.getfullargspec(fn).args
|
signature = inspect.signature(fn)
|
||||||
|
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||||
|
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||||
|
|
||||||
self.version = version
|
self.version = version
|
||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
self.do_not_specialize = [] if do_not_specialize is None else\
|
self.do_not_specialize = [] if do_not_specialize is None else\
|
||||||
@@ -829,7 +859,7 @@ class JITFunction:
|
|||||||
if not hasattr(self.fn, 'hash'):
|
if not hasattr(self.fn, 'hash'):
|
||||||
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
|
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
|
||||||
dependencies_finder.visit(self.parse())
|
dependencies_finder.visit(self.parse())
|
||||||
self.fn.hash = dependencies_finder.ret
|
self.fn.hash = dependencies_finder.ret + version_key()
|
||||||
return self.fn.hash
|
return self.fn.hash
|
||||||
|
|
||||||
# we do not parse `src` in the constructor because
|
# we do not parse `src` in the constructor because
|
||||||
@@ -848,6 +878,7 @@ class JITFunction:
|
|||||||
lscope = generator.lscope.copy()
|
lscope = generator.lscope.copy()
|
||||||
values = generator.module.get_values().copy()
|
values = generator.module.get_values().copy()
|
||||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||||
|
generator.lscope = dict()
|
||||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
||||||
generator.gscope = gscope
|
generator.gscope = gscope
|
||||||
generator.lscope = lscope
|
generator.lscope = lscope
|
||||||
|
@@ -7,98 +7,56 @@ from . import core as tl
|
|||||||
# 2. multiply_low_high is currently inefficient.
|
# 2. multiply_low_high is currently inefficient.
|
||||||
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
|
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
|
||||||
|
|
||||||
|
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
|
||||||
|
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
|
||||||
|
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
|
||||||
|
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
|
||||||
|
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
||||||
|
|
||||||
@triton.jit
|
# -------------------
|
||||||
def PHILOX_KEY_A():
|
# randint
|
||||||
# 0x9E3779B9
|
# -------------------
|
||||||
return -1640531527
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def PHILOX_KEY_B():
|
|
||||||
# 0xBB67AE85
|
|
||||||
return -1150833019
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def PHILOX_ROUND_A():
|
|
||||||
# 0xD2511F53
|
|
||||||
return -766435501
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def PHILOX_ROUND_B():
|
|
||||||
# 0xCD9E8D57
|
|
||||||
return -845247145
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def hacky_to_uint64(x):
|
def hacky_to_uint64(x):
|
||||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def single_round(c0, c1, c2, c3, k0, k1):
|
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
A = PHILOX_ROUND_A()
|
"""
|
||||||
B = PHILOX_ROUND_B()
|
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
||||||
_c0, _c2 = c0, c2
|
"""
|
||||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
for _ in range(n_rounds):
|
||||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
# update random state
|
||||||
c1 = B * _c2
|
A = PHILOX_ROUND_A
|
||||||
c3 = A * _c0
|
B = PHILOX_ROUND_B
|
||||||
|
_c0, _c2 = c0, c2
|
||||||
|
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||||
|
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||||
|
c1 = B * _c2
|
||||||
|
c3 = A * _c0
|
||||||
|
# raise key
|
||||||
|
k0 = k0 + PHILOX_KEY_A
|
||||||
|
k1 = k1 + PHILOX_KEY_B
|
||||||
return c0, c1, c2, c3
|
return c0, c1, c2, c3
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def raise_key(k0, k1):
|
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def philox_f(c0, c1, c2, c3, k0, k1):
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
k0, k1 = raise_key(k0, k1)
|
|
||||||
c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
|
|
||||||
return c0, c1, c2, c3
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def uint32_to_uniform_float(x):
|
|
||||||
"""
|
"""
|
||||||
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
|
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
|
||||||
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
|
block of random :code:`int32`.
|
||||||
covers all the possible values it can take.
|
|
||||||
|
If you need multiple streams of random numbers,
|
||||||
|
using `randint4x` is likely to be faster than calling `randint` 4 times.
|
||||||
|
|
||||||
|
:param seed: The seed for generating random numbers.
|
||||||
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
ret, _, _, _ = randint4x(seed, offset, n_rounds)
|
||||||
x = tl.where(x < 0, -x - 1, x)
|
return ret
|
||||||
return x * max
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pair_uniform_to_normal(u1, u2):
|
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""Box-Muller transform"""
|
|
||||||
u1 = tl.maximum(1.0e-7, u1)
|
|
||||||
th = 6.283185307179586 * u2
|
|
||||||
r = tl.sqrt(-2.0 * tl.log(u1))
|
|
||||||
return r * tl.cos(th), r * tl.sin(th)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def randint4x(seed, offset):
|
|
||||||
"""
|
"""
|
||||||
Given a :code:`seed` scalar and an :code:`offset` block, returns four
|
Given a :code:`seed` scalar and an :code:`offset` block, returns four
|
||||||
blocks of random :code:`int32`.
|
blocks of random :code:`int32`.
|
||||||
@@ -114,27 +72,26 @@ def randint4x(seed, offset):
|
|||||||
seed = hacky_to_uint64(seed) # uint will solve this
|
seed = hacky_to_uint64(seed) # uint will solve this
|
||||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||||
return philox_f(offset, z, z, z, seed_lo, seed_hi)
|
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# rand
|
||||||
|
# -------------------
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randint(seed, offset):
|
def uint32_to_uniform_float(x):
|
||||||
"""
|
"""
|
||||||
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
|
Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
|
||||||
block of random :code:`int32`.
|
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
|
||||||
|
covers all the possible values it can take.
|
||||||
If you need multiple streams of random numbers,
|
|
||||||
using `randint4x` is likely to be faster than calling `randint` 4 times.
|
|
||||||
|
|
||||||
:param seed: The seed for generating random numbers.
|
|
||||||
:param offsets: The offsets to generate random numbers for.
|
|
||||||
"""
|
"""
|
||||||
ret, _, _, _ = randint4x(seed, offset)
|
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
||||||
return ret
|
x = tl.where(x < 0, -x - 1, x)
|
||||||
|
return x * max
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def rand(seed, offset):
|
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||||
returns a block of random :code:`float32` in :math:`U(0, 1)`
|
returns a block of random :code:`float32` in :math:`U(0, 1)`
|
||||||
@@ -142,28 +99,11 @@ def rand(seed, offset):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
source = randint(seed, offset)
|
source = randint(seed, offset, n_rounds)
|
||||||
return uint32_to_uniform_float(source)
|
return uint32_to_uniform_float(source)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randn(seed, offset):
|
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
|
||||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
|
||||||
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
|
||||||
|
|
||||||
:param seed: The seed for generating random numbers.
|
|
||||||
:param offsets: The offsets to generate random numbers for.
|
|
||||||
"""
|
|
||||||
i1, i2, _, _ = randint4x(seed, offset)
|
|
||||||
u1 = uint32_to_uniform_float(i1)
|
|
||||||
u2 = uint32_to_uniform_float(i2)
|
|
||||||
n1, _ = pair_uniform_to_normal(u1, u2)
|
|
||||||
return n1
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def rand4x(seed, offsets):
|
|
||||||
"""
|
"""
|
||||||
Given a :code:`seed` scalar and an :code:`offsets` block,
|
Given a :code:`seed` scalar and an :code:`offsets` block,
|
||||||
returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
|
returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
|
||||||
@@ -171,16 +111,42 @@ def rand4x(seed, offsets):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
i1, i2, i3, i4 = randint4x(seed, offsets)
|
i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
|
||||||
u1 = uint32_to_uniform_float(i1)
|
u1 = uint32_to_uniform_float(i1)
|
||||||
u2 = uint32_to_uniform_float(i2)
|
u2 = uint32_to_uniform_float(i2)
|
||||||
u3 = uint32_to_uniform_float(i3)
|
u3 = uint32_to_uniform_float(i3)
|
||||||
u4 = uint32_to_uniform_float(i4)
|
u4 = uint32_to_uniform_float(i4)
|
||||||
return u1, u2, u3, u4
|
return u1, u2, u3, u4
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# randn
|
||||||
|
# -------------------
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randn4x(seed, offset):
|
def pair_uniform_to_normal(u1, u2):
|
||||||
|
"""Box-Muller transform"""
|
||||||
|
u1 = tl.maximum(1.0e-7, u1)
|
||||||
|
th = 6.283185307179586 * u2
|
||||||
|
r = tl.sqrt(-2.0 * tl.log(u1))
|
||||||
|
return r * tl.cos(th), r * tl.sin(th)
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
|
"""
|
||||||
|
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||||
|
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||||
|
|
||||||
|
:param seed: The seed for generating random numbers.
|
||||||
|
:param offsets: The offsets to generate random numbers for.
|
||||||
|
"""
|
||||||
|
i1, i2, _, _ = randint4x(seed, offset, n_rounds)
|
||||||
|
u1 = uint32_to_uniform_float(i1)
|
||||||
|
u2 = uint32_to_uniform_float(i2)
|
||||||
|
n1, _ = pair_uniform_to_normal(u1, u2)
|
||||||
|
return n1
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||||
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||||
@@ -188,7 +154,7 @@ def randn4x(seed, offset):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
u1, u2, u3, u4 = rand4x(seed, offset)
|
u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
|
||||||
n1, n2 = pair_uniform_to_normal(u1, u2)
|
n1, n2 = pair_uniform_to_normal(u1, u2)
|
||||||
n3, n4 = pair_uniform_to_normal(u3, u4)
|
n3, n4 = pair_uniform_to_normal(u3, u4)
|
||||||
return n1, n2, n3, n4
|
return n1, n2, n3, n4
|
||||||
|
Reference in New Issue
Block a user