[CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356)
This commit is contained in:
@@ -31,42 +31,26 @@ def PHILOX_ROUND_B():
|
||||
# 0xCD9E8D57
|
||||
return -845247145
|
||||
|
||||
|
||||
@triton.jit
|
||||
def hacky_to_uint64(x):
|
||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def multiply_low_high(a, b):
|
||||
return (
|
||||
a * b,
|
||||
((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def single_round(c0, c1, c2, c3, k0, k1):
|
||||
A = PHILOX_ROUND_A()
|
||||
B = PHILOX_ROUND_B()
|
||||
lo0, hi0 = multiply_low_high(A, c0)
|
||||
lo1, hi1 = multiply_low_high(B, c2)
|
||||
|
||||
return (
|
||||
hi1 ^ c1 ^ k0,
|
||||
lo1,
|
||||
hi0 ^ c3 ^ k1,
|
||||
lo0,
|
||||
)
|
||||
_c0, _c2 = c0, c2
|
||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||
c1 = B * _c2
|
||||
c3 = A * _c0
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
|
||||
def raise_key(k0, k1):
|
||||
return (
|
||||
k0 + PHILOX_KEY_A(),
|
||||
k1 + PHILOX_KEY_B(),
|
||||
)
|
||||
|
||||
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
|
||||
|
||||
@triton.jit
|
||||
def philox_f(c0, c1, c2, c3, k0, k1):
|
||||
@@ -125,7 +109,7 @@ def randint4x(seed, offset):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
z = 0
|
||||
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
|
||||
seed = hacky_to_uint64(seed) # uint will solve this
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||
|
Reference in New Issue
Block a user