diff --git a/python/tests/test_random.py b/python/tests/test_random.py
new file mode 100644
index 000000000..39ae59e35
--- /dev/null
+++ b/python/tests/test_random.py
@@ -0,0 +1,198 @@
+import numpy as np
+import pytest
+import scipy.stats
+import torch
+
+import triton
+import triton.language as tl
+
+#####################################
+# Reference Philox Implementation
+#####################################
+
+
+class PhiloxConfig:
+    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
+        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
+        self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
+        self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
+        self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
+        self.DTYPE = DTYPE
+
+
+# This is better for GPU
+PHILOX_32 = PhiloxConfig(
+    PHILOX_KEY_A=0x9E3779B9,
+    PHILOX_KEY_B=0xBB67AE85,
+    PHILOX_ROUND_A=0xD2511F53,
+    PHILOX_ROUND_B=0xCD9E8D57,
+    DTYPE=np.uint32,
+)
+
+# This is what numpy implements
+PHILOX_64 = PhiloxConfig(
+    PHILOX_KEY_A=0x9E3779B97F4A7C15,
+    PHILOX_KEY_B=0xBB67AE8584CAA73B,
+    PHILOX_ROUND_A=0xD2E7470EE14C6C93,
+    PHILOX_ROUND_B=0xCA5A826395121157,
+    DTYPE=np.uint64,
+)
+
+
+class CustomPhilox4x:
+    def __init__(self, seed, config):
+        self._config = config
+        seed = self._into_pieces(seed)
+        self._key = np.array(seed[:2], dtype=self._dtype)
+        self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
+
+    @property
+    def _dtype(self):
+        return self._config.DTYPE
+
+    def _into_pieces(self, n, pad=4):
+        res = []
+        while len(res) < pad:
+            res.append(np.array(n, dtype=self._dtype))
+            n >>= (np.dtype(self._dtype).itemsize * 8)
+        assert n == 0
+        return tuple(res)
+
+    def _multiply_low_high(self, a, b):
+        low = a * b
+        high = int(a) * int(b)
+        high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
+        return low, high
+
+    def _single_round(self, counter, key):
+        lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
+        lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
+        ret0 = hi1 ^ counter[1] ^ key[0]
+        ret1 = lo1
+        ret2 = hi0 ^ counter[3] ^ key[1]
+        ret3 = lo0
+        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
+
+    def _raise_key(self, key):
+        pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
+        return key + np.array(pk, dtype=self._dtype)
+
+    def random_raw(self):
+        counter = self._counter
+        key = self._key
+        for _ in range(10):
+            counter = self._single_round(counter, key)
+            key = self._raise_key(key)
+        self.advance(1)
+        return counter
+
+    def advance(self, n_steps):
+        self._counter[0] += n_steps
+        assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
+
+
+class CustomPhilox(CustomPhilox4x):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.buffer = []
+
+    def random_raw(self):
+        if len(self.buffer) == 0:
+            self.buffer = list(super().random_raw())[::-1]
+        return int(self.buffer.pop())
+
+
+#####################################
+# Unit Tests
+#####################################
+
+BLOCK = 1024
+
+# test generation of random uint32
+
+
+@pytest.mark.parametrize('size, seed',
+                         [(size, seed) for size in ['10', '4,53', '10000']
+                          for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
+                         )
+def test_randint(size, seed, device='cuda'):
+    size = list(map(int, size.split(',')))
+
+    @triton.jit
+    def kernel(X, N, seed):
+        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.randint(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+    # triton result
+    x = torch.empty(size, dtype=torch.int32, device=device)
+    N = x.numel()
+    grid = (triton.cdiv(N, BLOCK),)
+    kernel[grid](x, N, seed)
+    out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
+    # reference result
+    gen = CustomPhilox4x(seed, config=PHILOX_32)
+    out_ref = [gen.random_raw()[0] for _ in out_tri]
+    assert out_tri == out_ref
+
+# test uniform PRNG
+
+
+@pytest.mark.parametrize('size, seed',
+                         [(size, seed) for size in [1000000]
+                          for seed in [0, 42, 124, 54]]
+                         )
+def test_rand(size, seed, device='cuda'):
+    @triton.jit
+    def kernel(X, N, seed):
+        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.rand(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+    # triton result
+    x = torch.empty(size, dtype=torch.float32, device=device)
+    N = x.numel()
+    grid = (triton.cdiv(N, BLOCK),)
+    kernel[grid](x, N, seed)
+    assert all((x >= 0) & (x <= 1))
+    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
+
+# test normal PRNG
+
+
+@pytest.mark.parametrize('size, seed',
+                         [(size, seed) for size in [1000000]
+                          for seed in [0, 42, 124, 54]]
+                         )
+def test_randn(size, seed, device='cuda'):
+    @triton.jit
+    def kernel(X, N, seed):
+        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.randn(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+    # triton result
+    x = torch.empty(size, dtype=torch.float32, device=device)
+    N = x.numel()
+    grid = (triton.cdiv(N, BLOCK),)
+    kernel[grid](x, N, seed)
+    assert abs(x.mean()) < 1e-2
+    assert abs(x.std() - 1) < 1e-2
+
+
+# tl.rand() should never produce >=1.0
+
+def test_rand_limits():
+    @triton.jit
+    def kernel(input, output, n: tl.constexpr):
+        idx = tl.arange(0, n)
+        x = tl.load(input + idx)
+        y = tl.random.uint32_to_uniform_float(x)
+        tl.store(output + idx, y)
+
+    min_max_int32 = torch.tensor([
+        torch.iinfo(torch.int32).min,
+        torch.iinfo(torch.int32).max,
+    ], dtype=torch.int32, device='cuda')
+    output = torch.empty(2, dtype=torch.float32, device='cuda')
+    kernel[(1,)](min_max_int32, output, 2)
+
+    assert output[0] == output[1]
+    assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0
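For readers unfamiliar with Philox, here is a minimal standalone sketch of the Philox4x32-10 scheme that CustomPhilox4x implements above. The round and key constants are taken from the diff; the helper names (mulhilo32, philox4x32_10) are invented for illustration. It demonstrates only determinism, the property test_randint relies on when comparing tl.randint against the reference, and is not the implementation under test.

import numpy as np


def mulhilo32(a, b):
    # Full 64-bit product via Python ints, split into low/high 32-bit words.
    p = int(a) * int(b)
    return np.uint32(p & 0xFFFFFFFF), np.uint32(p >> 32)


def philox4x32_round(ctr, key):
    # One Philox round, mirroring _single_round: two 32x32->64 multiplies,
    # then a xor/permute of the counter words with the key words.
    lo0, hi0 = mulhilo32(0xD2511F53, ctr[0])
    lo1, hi1 = mulhilo32(0xCD9E8D57, ctr[2])
    return np.array([hi1 ^ ctr[1] ^ key[0], lo1,
                     hi0 ^ ctr[3] ^ key[1], lo0], dtype=np.uint32)


def philox4x32_10(ctr, key):
    # Ten rounds; the key is bumped by the Weyl constants after each round,
    # as in random_raw/_raise_key.
    ctr, key = ctr.copy(), key.copy()
    bump = np.array([0x9E3779B9, 0xBB67AE85], dtype=np.uint32)
    for _ in range(10):
        ctr = philox4x32_round(ctr, key)
        key = key + bump
    return ctr


ctr = np.array([7, 0, 0, 0], dtype=np.uint32)  # word 0 acts as the offset
key = np.array([42, 0], dtype=np.uint32)       # two key words from the seed
assert (philox4x32_10(ctr, key) == philox4x32_10(ctr, key)).all()

As in CustomPhilox4x.__init__, the two key words come from the seed and counter word 0 plays the role of the per-element offset (advance bumps _counter[0]).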
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index fbc85eb45..25133bff1 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -625,10 +625,12 @@ class CodeGenerator(ast.NodeVisitor):
             if name in liveins:
                 assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
                 assert self.is_triton_tensor(liveins[name])
-                if self.local_defs[name].type == liveins[name].type:
-                    names.append(name)
-                    init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
-                    yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
+                if self.local_defs[name].type != liveins[name].type:
+                    local_value = self.local_defs[name]
+                    self.local_defs[name] = local_value.to(liveins[name].dtype, _builder=self.builder)
+                names.append(name)
+                init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
+                yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))

         # create ForOp
         self.builder.set_insertion_point_to_end(insert_block)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index e44d706c9..b9542906c 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -17,11 +17,11 @@ def _to_tensor(x, builder):
         if -2**31 <= x < 2**31:
             return tensor(builder.get_int32(x), int32)
         elif 2**31 <= x < 2**32:
-            return tensor(builder.get_uint32(x), uint32)
+            return tensor(builder.get_int32(x), uint32)
         elif -2**63 <= x < 2**63:
             return tensor(builder.get_int64(x), int64)
         elif 2**63 <= x < 2**64:
-            return tensor(builder.get_uint64(x), uint64)
+            return tensor(builder.get_int64(x), uint64)
         else:
             raise RuntimeError(f'Nonrepresentable integer {x}.')
     elif isinstance(x, float):
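The ranges in _to_tensor above decide which Triton integer type a Python literal lowers to. Below is a standalone sketch of just that bucketing logic (the helper name int_dtype_for is hypothetical), showing why the hex Philox keys and the 0xdeadbeefcafeb0ba seed used by test_randint land in the unsigned buckets:

def int_dtype_for(x: int) -> str:
    # Mirrors the range checks in _to_tensor: the narrowest signed type wins,
    # and [2**31, 2**32) / [2**63, 2**64) fall through to unsigned types.
    if -2**31 <= x < 2**31:
        return 'int32'
    elif 2**31 <= x < 2**32:
        return 'uint32'
    elif -2**63 <= x < 2**63:
        return 'int64'
    elif 2**63 <= x < 2**64:
        return 'uint64'
    raise RuntimeError(f'Nonrepresentable integer {x}.')


assert int_dtype_for(0x9E3779B9) == 'uint32'          # PHILOX_KEY_A, >= 2**31
assert int_dtype_for(-1640531527) == 'int32'          # its old signed spelling
assert int_dtype_for(0xdeadbeefcafeb0ba) == 'uint64'  # 64-bit test seed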
diff --git a/python/triton/language/random.py b/python/triton/language/random.py
index 67de92c43..32183ec9b 100644
--- a/python/triton/language/random.py
+++ b/python/triton/language/random.py
@@ -1,10 +1,10 @@
 import triton
 from . import core as tl

-PHILOX_KEY_A: tl.constexpr = -1640531527  # 0x9E3779B9
-PHILOX_KEY_B: tl.constexpr = -1150833019  # 0xBB67AE85
-PHILOX_ROUND_A: tl.constexpr = -766435501  # 0xD2511F53
-PHILOX_ROUND_B: tl.constexpr = -845247145  # 0xCD9E8D57
+PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
+PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
+PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
+PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
 N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox

 # -------------------
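The old signed literals and the new hex literals encode the same 32-bit patterns, as the inline comments on the removed lines already noted; the hex form is simply the readable spelling. A quick two's-complement sanity check:

OLD_TO_NEW = [(-1640531527, 0x9E3779B9),  # PHILOX_KEY_A
              (-1150833019, 0xBB67AE85),  # PHILOX_KEY_B
              (-766435501, 0xD2511F53),   # PHILOX_ROUND_A
              (-845247145, 0xCD9E8D57)]   # PHILOX_ROUND_B
for signed, unsigned in OLD_TO_NEW:
    # Masking a negative Python int to 32 bits yields its two's complement.
    assert signed & 0xFFFFFFFF == unsigned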