[CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356)

This commit is contained in:
Philippe Tillet
2021-10-24 02:30:46 -07:00
committed by GitHub
parent 858dec8372
commit 5ce1b726dc
17 changed files with 149 additions and 60 deletions

View File

@@ -31,42 +31,26 @@ def PHILOX_ROUND_B():
# 0xCD9E8D57
return -845247145
@triton.jit
def hacky_to_uint64(x):
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
@triton.jit
def multiply_low_high(a, b):
return (
a * b,
((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
)
@triton.jit
def single_round(c0, c1, c2, c3, k0, k1):
A = PHILOX_ROUND_A()
B = PHILOX_ROUND_B()
lo0, hi0 = multiply_low_high(A, c0)
lo1, hi1 = multiply_low_high(B, c2)
return (
hi1 ^ c1 ^ k0,
lo1,
hi0 ^ c3 ^ k1,
lo0,
)
_c0, _c2 = c0, c2
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
c1 = B * _c2
c3 = A * _c0
return c0, c1, c2, c3
@triton.jit
def raise_key(k0, k1):
return (
k0 + PHILOX_KEY_A(),
k1 + PHILOX_KEY_B(),
)
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1):
@@ -125,7 +109,7 @@ def randint4x(seed, offset):
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
z = 0
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
seed = hacky_to_uint64(seed) # uint will solve this
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
seed_lo = (seed & 0xffffffff).to(tl.int32)