[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)
This commit is contained in:
		@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
 | 
				
			|||||||
        continue;
 | 
					        continue;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      // argument is `constexpr`
 | 
					      // argument is `constexpr`
 | 
				
			||||||
      py::object value = arg.attr("value");
 | 
					      if(py::hasattr(arg, "value")){
 | 
				
			||||||
      if(value){
 | 
					        py::object value = arg.attr("value");
 | 
				
			||||||
        py::object name = arg_names[i];
 | 
					        py::object name = arg_names[i];
 | 
				
			||||||
        constants[name] = value;
 | 
					        constants[name] = value;
 | 
				
			||||||
        py::object repr = py::repr(value);
 | 
					        py::object repr = py::repr(value);
 | 
				
			||||||
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
 | 
				
			|||||||
        cache_key += std::string(start, len);
 | 
					        cache_key += std::string(start, len);
 | 
				
			||||||
        continue;
 | 
					        continue;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      assert(false);
 | 
					      std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
 | 
				
			||||||
 | 
					      std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
 | 
				
			||||||
 | 
					                            + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
 | 
				
			||||||
 | 
					      throw std::runtime_error(err_msg);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  cache_key += std::to_string(num_warps);
 | 
					  cache_key += std::to_string(num_warps);
 | 
				
			||||||
  cache_key += std::to_string(num_stages);
 | 
					  cache_key += std::to_string(num_stages);
 | 
				
			||||||
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
 | 
				
			|||||||
        CU_LAUNCH_PARAM_END
 | 
					        CU_LAUNCH_PARAM_END
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    uint64_t _stream = PyLong_AsLong(stream.ptr());
 | 
					    uint64_t _stream = PyLong_AsLong(stream.ptr());
 | 
				
			||||||
    drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, 
 | 
					    if(grid_0*grid_1*grid_2 > 0)
 | 
				
			||||||
                                  _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, 
 | 
					      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, 
 | 
				
			||||||
                                  nullptr, config);
 | 
					                                    _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, 
 | 
				
			||||||
 | 
					                                     nullptr, config);
 | 
				
			||||||
    return bin;
 | 
					    return bin;
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
 | 
				
			|||||||
                      const std::string &args, int64_t shared_mem){
 | 
					                      const std::string &args, int64_t shared_mem){
 | 
				
			||||||
    void* args_ptr = (void*)args.data();
 | 
					    void* args_ptr = (void*)args.data();
 | 
				
			||||||
    size_t args_size = args.size();
 | 
					    size_t args_size = args.size();
 | 
				
			||||||
 | 
					    // release the gil in case the enqueue blocks
 | 
				
			||||||
 | 
					    // cuda will block if too many ops are enqueued
 | 
				
			||||||
 | 
					    Py_BEGIN_ALLOW_THREADS
 | 
				
			||||||
    if(backend == HOST)
 | 
					    if(backend == HOST)
 | 
				
			||||||
      host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
 | 
					      host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
 | 
				
			||||||
    if(backend == CUDA)
 | 
					    if(backend == CUDA)
 | 
				
			||||||
      cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
 | 
					      cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
 | 
				
			||||||
    if(backend == ROCM)
 | 
					    if(backend == ROCM)
 | 
				
			||||||
      hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
 | 
					      hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
 | 
				
			||||||
 | 
					    Py_END_ALLOW_THREADS
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  
 | 
					  
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -634,6 +634,28 @@ def test_load_cache_modifier(cache):
 | 
				
			|||||||
# test while
 | 
					# test while
 | 
				
			||||||
# ---------------
 | 
					# ---------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ---------------
 | 
				
			||||||
 | 
					# test default
 | 
				
			||||||
 | 
					# ---------------
 | 
				
			||||||
 | 
					#TODO: can't be local to test_default
 | 
				
			||||||
 | 
					@triton.jit
 | 
				
			||||||
 | 
					def _impl(value = 10):
 | 
				
			||||||
 | 
					    return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_default():
 | 
				
			||||||
 | 
					    value = 5
 | 
				
			||||||
 | 
					    ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
 | 
				
			||||||
 | 
					    ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @triton.jit
 | 
				
			||||||
 | 
					    def _kernel(ret0, ret1, value):
 | 
				
			||||||
 | 
					        tl.store(ret0, _impl())
 | 
				
			||||||
 | 
					        tl.store(ret1, _impl(value))
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    _kernel[(1,)](ret0, ret1, value)
 | 
				
			||||||
 | 
					    assert ret0.item() == 10
 | 
				
			||||||
 | 
					    assert ret1.item() == value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# ---------------
 | 
					# ---------------
 | 
				
			||||||
# test noop
 | 
					# test noop
 | 
				
			||||||
#----------------
 | 
					#----------------
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def visit_FunctionDef(self, node, inline=False, arg_values=None):
 | 
					    def visit_FunctionDef(self, node, inline=False, arg_values=None):
 | 
				
			||||||
        arg_names, kwarg_names = self.visit(node.args)
 | 
					        arg_names, kwarg_names = self.visit(node.args)
 | 
				
			||||||
 | 
					        # initialize defaults
 | 
				
			||||||
 | 
					        for i, default_value in enumerate(node.args.defaults):
 | 
				
			||||||
 | 
					            arg_node = node.args.args[-i-1]
 | 
				
			||||||
 | 
					            annotation = arg_node.annotation
 | 
				
			||||||
 | 
					            name = arg_node.arg
 | 
				
			||||||
 | 
					            st_target = ast.Name(id=name, ctx=ast.Store())
 | 
				
			||||||
 | 
					            if annotation is None:
 | 
				
			||||||
 | 
					                init_node = ast.Assign(targets=[st_target], value=default_value)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
 | 
				
			||||||
 | 
					            self.visit(init_node)
 | 
				
			||||||
        # store keyword arguments in local scope
 | 
					        # store keyword arguments in local scope
 | 
				
			||||||
        self.lscope[kwarg_names] = self.kwargs
 | 
					        self.lscope[kwarg_names] = self.kwargs
 | 
				
			||||||
        # initialize function
 | 
					        # initialize function
 | 
				
			||||||
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
 | 
				
			|||||||
        iterator = self.visit(node.iter.func)
 | 
					        iterator = self.visit(node.iter.func)
 | 
				
			||||||
        if iterator != self.builtins['range']:
 | 
					        if iterator != self.builtins['range']:
 | 
				
			||||||
            raise RuntimeError('Only `range` iterator currently supported')
 | 
					            raise RuntimeError('Only `range` iterator currently supported')
 | 
				
			||||||
 | 
					        # static for loops: all iterator arguments are constexpr
 | 
				
			||||||
 | 
					        iter_args = [self.visit(arg) for arg in node.iter.args]
 | 
				
			||||||
 | 
					        is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
 | 
				
			||||||
 | 
					        if is_static:
 | 
				
			||||||
 | 
					            st_target = ast.Name(id=node.target.id, ctx=ast.Store())
 | 
				
			||||||
 | 
					            iter_args = [arg.value for arg in iter_args]
 | 
				
			||||||
 | 
					            range = iterator(*iter_args)
 | 
				
			||||||
 | 
					            if len(range) <= 10:
 | 
				
			||||||
 | 
					                for i in iterator(*iter_args):
 | 
				
			||||||
 | 
					                    self.lscope[node.target.id] = triton.language.constexpr(i)
 | 
				
			||||||
 | 
					                    self.visit_compound_statement(node.body)
 | 
				
			||||||
 | 
					                    for stmt in node.orelse:
 | 
				
			||||||
 | 
					                        ast.NodeVisitor.generic_visit(self, stmt)
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
        # create nodes
 | 
					        # create nodes
 | 
				
			||||||
        st_target = ast.Name(id=node.target.id, ctx=ast.Store())
 | 
					        st_target = ast.Name(id=node.target.id, ctx=ast.Store())
 | 
				
			||||||
        ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
 | 
					        ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
 | 
				
			||||||
@@ -483,6 +508,7 @@ class CompilationError(Exception):
 | 
				
			|||||||
        self.message += '\n' + ' ' * node.col_offset + '^'
 | 
					        self.message += '\n' + ' ' * node.col_offset + '^'
 | 
				
			||||||
        self.message += '\n Error: ' + str(err)
 | 
					        self.message += '\n Error: ' + str(err)
 | 
				
			||||||
        super().__init__(self.message)
 | 
					        super().__init__(self.message)
 | 
				
			||||||
 | 
					        self.args = (src, node, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OutOfResources(Exception):
 | 
					class OutOfResources(Exception):
 | 
				
			||||||
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
 | 
				
			|||||||
                       f'Required: {required}'\
 | 
					                       f'Required: {required}'\
 | 
				
			||||||
                       f'Hardware limit: {limit}'
 | 
					                       f'Hardware limit: {limit}'
 | 
				
			||||||
        super().__init__(self.message)
 | 
					        super().__init__(self.message)
 | 
				
			||||||
 | 
					        self.args = (required, limit, name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Kernel:
 | 
					class Kernel:
 | 
				
			||||||
@@ -805,7 +832,10 @@ class JITFunction:
 | 
				
			|||||||
        # information of wrapped function
 | 
					        # information of wrapped function
 | 
				
			||||||
        self.fn = fn
 | 
					        self.fn = fn
 | 
				
			||||||
        self.module = fn.__module__
 | 
					        self.module = fn.__module__
 | 
				
			||||||
        self.arg_names = inspect.getfullargspec(fn).args
 | 
					        signature = inspect.signature(fn)
 | 
				
			||||||
 | 
					        self.arg_names = [v.name for v in signature.parameters.values()]
 | 
				
			||||||
 | 
					        self.arg_defaults = [v.default for v in signature.parameters.values()]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.version = version
 | 
					        self.version = version
 | 
				
			||||||
        self.src = textwrap.dedent(inspect.getsource(fn))
 | 
					        self.src = textwrap.dedent(inspect.getsource(fn))
 | 
				
			||||||
        self.do_not_specialize = [] if do_not_specialize is None else\
 | 
					        self.do_not_specialize = [] if do_not_specialize is None else\
 | 
				
			||||||
@@ -829,7 +859,7 @@ class JITFunction:
 | 
				
			|||||||
        if not hasattr(self.fn, 'hash'):
 | 
					        if not hasattr(self.fn, 'hash'):
 | 
				
			||||||
            dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
 | 
					            dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
 | 
				
			||||||
            dependencies_finder.visit(self.parse())
 | 
					            dependencies_finder.visit(self.parse())
 | 
				
			||||||
            self.fn.hash = dependencies_finder.ret
 | 
					            self.fn.hash = dependencies_finder.ret + version_key()
 | 
				
			||||||
        return self.fn.hash
 | 
					        return self.fn.hash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # we do not parse `src` in the constructor because
 | 
					    # we do not parse `src` in the constructor because
 | 
				
			||||||
@@ -848,6 +878,7 @@ class JITFunction:
 | 
				
			|||||||
            lscope = generator.lscope.copy()
 | 
					            lscope = generator.lscope.copy()
 | 
				
			||||||
            values = generator.module.get_values().copy()
 | 
					            values = generator.module.get_values().copy()
 | 
				
			||||||
            generator.gscope = sys.modules[self.fn.__module__].__dict__
 | 
					            generator.gscope = sys.modules[self.fn.__module__].__dict__
 | 
				
			||||||
 | 
					            generator.lscope = dict()
 | 
				
			||||||
            ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
 | 
					            ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
 | 
				
			||||||
            generator.gscope = gscope
 | 
					            generator.gscope = gscope
 | 
				
			||||||
            generator.lscope = lscope
 | 
					            generator.lscope = lscope
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,98 +7,56 @@ from . import core as tl
 | 
				
			|||||||
# 2. multiply_low_high is currently inefficient.
 | 
					# 2. multiply_low_high is currently inefficient.
 | 
				
			||||||
# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
 | 
					# 3. Even though technically philox sampling outputs int, in many places we pretends they were actualy uints e.g. uint_to_uniform_float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
 | 
				
			||||||
 | 
					PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
 | 
				
			||||||
 | 
					PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
 | 
				
			||||||
 | 
					PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
 | 
				
			||||||
 | 
					N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@triton.jit
 | 
					# -------------------
 | 
				
			||||||
def PHILOX_KEY_A():
 | 
					# randint
 | 
				
			||||||
    # 0x9E3779B9
 | 
					# -------------------
 | 
				
			||||||
    return -1640531527
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def PHILOX_KEY_B():
 | 
					 | 
				
			||||||
    # 0xBB67AE85
 | 
					 | 
				
			||||||
    return -1150833019
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def PHILOX_ROUND_A():
 | 
					 | 
				
			||||||
    # 0xD2511F53
 | 
					 | 
				
			||||||
    return -766435501
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def PHILOX_ROUND_B():
 | 
					 | 
				
			||||||
    # 0xCD9E8D57
 | 
					 | 
				
			||||||
    return -845247145
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def hacky_to_uint64(x):
 | 
					def hacky_to_uint64(x):
 | 
				
			||||||
    return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
 | 
					    return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def single_round(c0, c1, c2, c3, k0, k1):
 | 
					def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
    A = PHILOX_ROUND_A()
 | 
					    """
 | 
				
			||||||
    B = PHILOX_ROUND_B()
 | 
					    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
 | 
				
			||||||
    _c0, _c2 = c0, c2
 | 
					    """
 | 
				
			||||||
    c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
 | 
					    for _ in range(n_rounds):
 | 
				
			||||||
    c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
 | 
					        # update random state
 | 
				
			||||||
    c1 = B * _c2
 | 
					        A = PHILOX_ROUND_A
 | 
				
			||||||
    c3 = A * _c0
 | 
					        B = PHILOX_ROUND_B
 | 
				
			||||||
 | 
					        _c0, _c2 = c0, c2
 | 
				
			||||||
 | 
					        c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
 | 
				
			||||||
 | 
					        c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
 | 
				
			||||||
 | 
					        c1 = B * _c2
 | 
				
			||||||
 | 
					        c3 = A * _c0
 | 
				
			||||||
 | 
					        # raise key
 | 
				
			||||||
 | 
					        k0 = k0 + PHILOX_KEY_A
 | 
				
			||||||
 | 
					        k1 = k1 + PHILOX_KEY_B
 | 
				
			||||||
    return c0, c1, c2, c3
 | 
					    return c0, c1, c2, c3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def raise_key(k0, k1):
 | 
					def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
    return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def philox_f(c0, c1, c2, c3, k0, k1):
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    k0, k1 = raise_key(k0, k1)
 | 
					 | 
				
			||||||
    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
 | 
					 | 
				
			||||||
    return c0, c1, c2, c3
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def uint32_to_uniform_float(x):
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). 
 | 
					    Given a :code:`seed` scalar and an :code:`offset` block, returns a single 
 | 
				
			||||||
    This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly 
 | 
					    block of random :code:`int32`. 
 | 
				
			||||||
    covers all the possible values it can take.
 | 
					    
 | 
				
			||||||
 | 
					    If you need multiple streams of random numbers,
 | 
				
			||||||
 | 
					    using `randint4x` is likely to be faster than calling `randint` 4 times.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    :param seed: The seed for generating random numbers.
 | 
				
			||||||
 | 
					    :param offsets: The offsets to generate random numbers for.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
 | 
					    ret, _, _, _ = randint4x(seed, offset, n_rounds)
 | 
				
			||||||
    x = tl.where(x < 0, -x - 1, x)
 | 
					    return ret
 | 
				
			||||||
    return x * max
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def pair_uniform_to_normal(u1, u2):
 | 
					def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
    """Box-Muller transform"""
 | 
					 | 
				
			||||||
    u1 = tl.maximum(1.0e-7, u1)
 | 
					 | 
				
			||||||
    th = 6.283185307179586 * u2
 | 
					 | 
				
			||||||
    r = tl.sqrt(-2.0 * tl.log(u1))
 | 
					 | 
				
			||||||
    return r * tl.cos(th), r * tl.sin(th)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def randint4x(seed, offset):
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Given a :code:`seed` scalar and an :code:`offset` block, returns four 
 | 
					    Given a :code:`seed` scalar and an :code:`offset` block, returns four 
 | 
				
			||||||
    blocks of random :code:`int32`. 
 | 
					    blocks of random :code:`int32`. 
 | 
				
			||||||
@@ -114,27 +72,26 @@ def randint4x(seed, offset):
 | 
				
			|||||||
    seed = hacky_to_uint64(seed) # uint will solve this
 | 
					    seed = hacky_to_uint64(seed) # uint will solve this
 | 
				
			||||||
    seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
 | 
					    seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
 | 
				
			||||||
    seed_lo = (seed & 0xffffffff).to(tl.int32)
 | 
					    seed_lo = (seed & 0xffffffff).to(tl.int32)
 | 
				
			||||||
    return philox_f(offset, z, z, z, seed_lo, seed_hi)
 | 
					    return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -------------------
 | 
				
			||||||
 | 
					# rand
 | 
				
			||||||
 | 
					# -------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def randint(seed, offset):
 | 
					def uint32_to_uniform_float(x):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single 
 | 
					    Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1). 
 | 
				
			||||||
    block of random :code:`int32`. 
 | 
					    This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly 
 | 
				
			||||||
    
 | 
					    covers all the possible values it can take.
 | 
				
			||||||
    If you need multiple streams of random numbers,
 | 
					 | 
				
			||||||
    using `randint4x` is likely to be faster than calling `randint` 4 times.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :param seed: The seed for generating random numbers.
 | 
					 | 
				
			||||||
    :param offsets: The offsets to generate random numbers for.
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    ret, _, _, _ = randint4x(seed, offset)
 | 
					    max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
 | 
				
			||||||
    return ret
 | 
					    x = tl.where(x < 0, -x - 1, x)
 | 
				
			||||||
 | 
					    return x * max
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def rand(seed, offset):
 | 
					def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Given a :code:`seed` scalar and an :code:`offset` block, 
 | 
					    Given a :code:`seed` scalar and an :code:`offset` block, 
 | 
				
			||||||
    returns a block of random :code:`float32` in :math:`U(0, 1)`
 | 
					    returns a block of random :code:`float32` in :math:`U(0, 1)`
 | 
				
			||||||
@@ -142,28 +99,11 @@ def rand(seed, offset):
 | 
				
			|||||||
    :param seed: The seed for generating random numbers.
 | 
					    :param seed: The seed for generating random numbers.
 | 
				
			||||||
    :param offsets: The offsets to generate random numbers for.
 | 
					    :param offsets: The offsets to generate random numbers for.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    source = randint(seed, offset)
 | 
					    source = randint(seed, offset, n_rounds)
 | 
				
			||||||
    return uint32_to_uniform_float(source)
 | 
					    return uint32_to_uniform_float(source)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def randn(seed, offset):
 | 
					def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Given a :code:`seed` scalar and an :code:`offset` block, 
 | 
					 | 
				
			||||||
    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    :param seed: The seed for generating random numbers.
 | 
					 | 
				
			||||||
    :param offsets: The offsets to generate random numbers for.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    i1, i2, _, _ = randint4x(seed, offset)
 | 
					 | 
				
			||||||
    u1 = uint32_to_uniform_float(i1)
 | 
					 | 
				
			||||||
    u2 = uint32_to_uniform_float(i2)
 | 
					 | 
				
			||||||
    n1, _ = pair_uniform_to_normal(u1, u2)
 | 
					 | 
				
			||||||
    return n1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@triton.jit
 | 
					 | 
				
			||||||
def rand4x(seed, offsets):
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Given a :code:`seed` scalar and an :code:`offsets` block,
 | 
					    Given a :code:`seed` scalar and an :code:`offsets` block,
 | 
				
			||||||
    returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
 | 
					    returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`
 | 
				
			||||||
@@ -171,16 +111,42 @@ def rand4x(seed, offsets):
 | 
				
			|||||||
    :param seed: The seed for generating random numbers.
 | 
					    :param seed: The seed for generating random numbers.
 | 
				
			||||||
    :param offsets: The offsets to generate random numbers for.
 | 
					    :param offsets: The offsets to generate random numbers for.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    i1, i2, i3, i4 = randint4x(seed, offsets)
 | 
					    i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
 | 
				
			||||||
    u1 = uint32_to_uniform_float(i1)
 | 
					    u1 = uint32_to_uniform_float(i1)
 | 
				
			||||||
    u2 = uint32_to_uniform_float(i2)
 | 
					    u2 = uint32_to_uniform_float(i2)
 | 
				
			||||||
    u3 = uint32_to_uniform_float(i3)
 | 
					    u3 = uint32_to_uniform_float(i3)
 | 
				
			||||||
    u4 = uint32_to_uniform_float(i4)
 | 
					    u4 = uint32_to_uniform_float(i4)
 | 
				
			||||||
    return u1, u2, u3, u4
 | 
					    return u1, u2, u3, u4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -------------------
 | 
				
			||||||
 | 
					# randn
 | 
				
			||||||
 | 
					# -------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@triton.jit
 | 
					@triton.jit
 | 
				
			||||||
def randn4x(seed, offset):
 | 
					def pair_uniform_to_normal(u1, u2):
 | 
				
			||||||
 | 
					    """Box-Muller transform"""
 | 
				
			||||||
 | 
					    u1 = tl.maximum(1.0e-7, u1)
 | 
				
			||||||
 | 
					    th = 6.283185307179586 * u2
 | 
				
			||||||
 | 
					    r = tl.sqrt(-2.0 * tl.log(u1))
 | 
				
			||||||
 | 
					    return r * tl.cos(th), r * tl.sin(th)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@triton.jit
 | 
				
			||||||
 | 
					def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Given a :code:`seed` scalar and an :code:`offset` block, 
 | 
				
			||||||
 | 
					    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    :param seed: The seed for generating random numbers.
 | 
				
			||||||
 | 
					    :param offsets: The offsets to generate random numbers for.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
 | 
				
			||||||
 | 
					    u1 = uint32_to_uniform_float(i1)
 | 
				
			||||||
 | 
					    u2 = uint32_to_uniform_float(i2)
 | 
				
			||||||
 | 
					    n1, _ = pair_uniform_to_normal(u1, u2)
 | 
				
			||||||
 | 
					    return n1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@triton.jit
 | 
				
			||||||
 | 
					def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Given a :code:`seed` scalar and an :code:`offset` block,
 | 
					    Given a :code:`seed` scalar and an :code:`offset` block,
 | 
				
			||||||
    returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
 | 
					    returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
 | 
				
			||||||
@@ -188,7 +154,7 @@ def randn4x(seed, offset):
 | 
				
			|||||||
    :param seed: The seed for generating random numbers.
 | 
					    :param seed: The seed for generating random numbers.
 | 
				
			||||||
    :param offsets: The offsets to generate random numbers for.
 | 
					    :param offsets: The offsets to generate random numbers for.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    u1, u2, u3, u4 = rand4x(seed, offset)
 | 
					    u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
 | 
				
			||||||
    n1, n2 = pair_uniform_to_normal(u1, u2)
 | 
					    n1, n2 = pair_uniform_to_normal(u1, u2)
 | 
				
			||||||
    n3, n4 = pair_uniform_to_normal(u3, u4)
 | 
					    n3, n4 = pair_uniform_to_normal(u3, u4)
 | 
				
			||||||
    return n1, n2, n3, n4
 | 
					    return n1, n2, n3, n4
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user