[CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356)
This commit is contained in:
@@ -313,6 +313,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("arange", &ir::dispatch::arange, ret::reference);
|
||||
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
||||
// type manipuatation
|
||||
m.def("cat", &ir::dispatch::cat, ret::reference);
|
||||
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
||||
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||
@@ -340,6 +341,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("max", &ir::dispatch::max, ret::reference);
|
||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||
// math
|
||||
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
|
@@ -346,6 +346,18 @@ def broadcast_to(input, shape, _builder=None):
|
||||
"""
|
||||
return frontend.broadcast_to(input, shape, _builder)
|
||||
|
||||
@builtin
|
||||
def cat(input, other, _builder=None):
|
||||
"""
|
||||
Concatenate the given blocks
|
||||
|
||||
:param input: The first input block.
|
||||
:type input:
|
||||
:param other: The second input block.
|
||||
:type other:
|
||||
"""
|
||||
return frontend.cat(input, other, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
@@ -524,6 +536,10 @@ def where(condition, x, y, _builder=None):
|
||||
# Math
|
||||
# -----------------------
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
return frontend.umulhi(x, y, _builder)
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
|
||||
def _decorator(func):
|
||||
@@ -543,7 +559,6 @@ def _add_math_1arg_docstr(name):
|
||||
def exp(x, _builder=None):
|
||||
return frontend.exp(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("natural logarithm")
|
||||
def log(x, _builder=None):
|
||||
|
@@ -31,42 +31,26 @@ def PHILOX_ROUND_B():
|
||||
# 0xCD9E8D57
|
||||
return -845247145
|
||||
|
||||
|
||||
@triton.jit
|
||||
def hacky_to_uint64(x):
|
||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def multiply_low_high(a, b):
|
||||
return (
|
||||
a * b,
|
||||
((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def single_round(c0, c1, c2, c3, k0, k1):
|
||||
A = PHILOX_ROUND_A()
|
||||
B = PHILOX_ROUND_B()
|
||||
lo0, hi0 = multiply_low_high(A, c0)
|
||||
lo1, hi1 = multiply_low_high(B, c2)
|
||||
|
||||
return (
|
||||
hi1 ^ c1 ^ k0,
|
||||
lo1,
|
||||
hi0 ^ c3 ^ k1,
|
||||
lo0,
|
||||
)
|
||||
_c0, _c2 = c0, c2
|
||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||
c1 = B * _c2
|
||||
c3 = A * _c0
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
|
||||
def raise_key(k0, k1):
|
||||
return (
|
||||
k0 + PHILOX_KEY_A(),
|
||||
k1 + PHILOX_KEY_B(),
|
||||
)
|
||||
|
||||
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
|
||||
|
||||
@triton.jit
|
||||
def philox_f(c0, c1, c2, c3, k0, k1):
|
||||
@@ -125,7 +109,7 @@ def randint4x(seed, offset):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
z = 0
|
||||
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
|
||||
seed = hacky_to_uint64(seed) # uint will solve this
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||
|
Reference in New Issue
Block a user