diff --git a/docs/conf.py b/docs/conf.py
index 1107ef171..67a14f47a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -66,7 +66,7 @@ def setup(app):
 import sys
 import os
 sys.path.insert(0, os.path.abspath('../python/'))
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
+extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
 autosummary_generate = True
 
 # Sphinx gallery
@@ -78,6 +78,9 @@ sphinx_gallery_conf = {
     'filename_pattern': '',
     'ignore_pattern': r'__init__\.py',
     'within_subsection_order': FileNameSortKey,
+    'reference_url': {
+        'sphinx_gallery': None,
+    }
 }
 
 # Add any paths that contain templates here, relative to this directory.
diff --git a/docs/getting-started/tutorials/random_bits.png b/docs/getting-started/tutorials/random_bits.png
new file mode 100644
index 000000000..198f90a5e
Binary files /dev/null and b/docs/getting-started/tutorials/random_bits.png differ
diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst
index 4cb437faf..1f05ce8a6 100644
--- a/docs/python-api/triton.language.rst
+++ b/docs/python-api/triton.language.rst
@@ -121,6 +121,19 @@ Comparison ops
     minimum
     maximum
 
+.. _Random Number Generation:
+
+Random Number Generation
+-------------------------
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    randint4x
+    randint
+    rand
+    randn
 
 Compiler Hint Ops
 -------------------
@@ -129,4 +142,4 @@ Compiler Hint Ops
     :toctree: generated
     :nosignatures:
 
-    multiple_of
+    multiple_of
\ No newline at end of file
diff --git a/python/setup.py b/python/setup.py
index 308ffa966..2965f167b 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -126,7 +126,7 @@ setup(
     author_email="phil@openai.com",
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
-    packages=["triton", "triton/_C", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
+    packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
     install_requires=["torch"],
     package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
diff --git a/python/test/test_language.py b/python/test/language/test_core.py
similarity index 100%
rename from python/test/test_language.py
rename to python/test/language/test_core.py
diff --git a/python/test/language/test_random.py b/python/test/language/test_random.py
new file mode 100644
index 000000000..6c15a7588
--- /dev/null
+++ b/python/test/language/test_random.py
@@ -0,0 +1,198 @@
+import torch
+import triton
+import triton.language as tl
+import pytest
+import scipy.stats
+import numpy as np
+
+from numpy.random import Philox
+
+#####################################
+## Reference Philox Implementation
+#####################################
+
+class PhiloxConfig:
+    def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
+        self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
+        self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
+        self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
+        self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
+        self.DTYPE = DTYPE
+
+
+# This is better for GPU
+PHILOX_32 = PhiloxConfig(
+    PHILOX_KEY_A=0x9E3779B9,
+    PHILOX_KEY_B=0xBB67AE85,
+    PHILOX_ROUND_A=0xD2511F53,
+    PHILOX_ROUND_B=0xCD9E8D57,
+    DTYPE=np.uint32,
+)
+
+# This is what numpy implements
+PHILOX_64 = PhiloxConfig(
+    PHILOX_KEY_A=0x9E3779B97F4A7C15,
+    PHILOX_KEY_B=0xBB67AE8584CAA73B,
+    PHILOX_ROUND_A=0xD2E7470EE14C6C93,
+    PHILOX_ROUND_B=0xCA5A826395121157,
+    DTYPE=np.uint64,
+)
+
+
+class CustomPhilox4x:
+    def __init__(self, seed, config):
+        self._config = config
+        seed = self._into_pieces(seed)
+        self._key = np.array(seed[:2], dtype=self._dtype)
+        self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
+
+    @property
+    def _dtype(self):
+        return self._config.DTYPE
+
+    def _into_pieces(self, n, pad=4):
+        res = []
+        while len(res) < pad:
+            res.append(np.array(n, dtype=self._dtype))
+            n >>= (np.dtype(self._dtype).itemsize * 8)
+        assert n == 0
+        return tuple(res)
+
+    def _multiply_low_high(self, a, b):
+        low = a * b
+        high = int(a) * int(b)
+        high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
+        return low, high
+
+    def _single_round(self, counter, key):
+        lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
+        lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
+        ret0 = hi1 ^ counter[1] ^ key[0]
+        ret1 = lo1
+        ret2 = hi0 ^ counter[3] ^ key[1]
+        ret3 = lo0
+        return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
+
+    def _raise_key(self, key):
+        ret0 = key[0] + self._config.PHILOX_KEY_A
+        ret1 = key[1] + self._config.PHILOX_KEY_B
+        return np.array([ret0, ret1], dtype=self._dtype)
+
+    def random_raw(self):
+        counter = self._counter
+        key = self._key
+        for _ in range(10):
+            counter = self._single_round(counter, key)
+            key = self._raise_key(key)
+        self.advance(1)
+        return counter
+
+    def advance(self, n_steps):
+        self._counter[0] += n_steps
+        assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
+
+
+class CustomPhilox(CustomPhilox4x):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.buffer = []
+
+    def random_raw(self):
+        if len(self.buffer) == 0:
+            self.buffer = list(super().random_raw())[::-1]
+        return int(self.buffer.pop())
+
+
+#####################################
+## Unit Tests
+#####################################
+
+BLOCK = 1024
+
+# test generation of random uint32
+@pytest.mark.parametrize('size, seed',
+                         [(size, seed) for size in ['10', '4,53', '10000']\
+                          for seed in [0, 42, 124, 54]]
+)
+def test_randint(size, seed, device='cuda'):
+    size = list(map(int, size.split(',')))
+    @triton.jit
+    def kernel(X, N, seed):
+        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.randint(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+    # triton result
+    x = torch.empty(size, dtype=torch.int32, device=device)
+    N = x.numel()
+    grid = (triton.cdiv(N, BLOCK),)
+    kernel[grid](x, N, seed)
+    out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
+    # reference result
+    gen = CustomPhilox4x(seed, config=PHILOX_32)
+    out_ref = [gen.random_raw()[0] for _ in out_tri]
+    assert out_tri == out_ref
+
+# test conversion of random uint32 into random float in [0, 1]
+def test_uint32_to_uniform_float():
+    @triton.jit
+    def kernel(SRC, TGT, N, **meta):
+        pid = tl.program_id(0)
+        offset = pid * BLOCK + tl.arange(0, BLOCK)
+        src = tl.load(SRC + offset)
+        tgt = tl.random.uint32_to_uniform_float(src)
+        tl.store(TGT + offset, tgt, mask=offset < N)
+
+    def run(source):
+        target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
+        N = source.numel()
+        grid = lambda meta: (triton.cdiv(N, BLOCK),)
+        kernel[grid](source, target, N)
+        return target
+
+    # check range of edge values
+    n = 100
+    source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
+    target = run(source).tolist()
+    assert target == sorted(target)
+    assert all(0.0 <= num < 1.0 for num in target)
+    # check distribution is uniform
+    source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
+    target = run(source).tolist()
+    assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
+
+# test uniform PRNG
+@pytest.mark.parametrize('size, seed',
+                         [(size, seed) for size in [1000000]\
+                          for seed in [0, 42, 124, 54]]
+)
+def test_rand(size, seed, device='cuda'):
+    @triton.jit
+    def kernel(X, N, seed):
+        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.rand(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+    # triton result
+    x = torch.empty(size, dtype=torch.float32, device=device)
+    N = x.numel()
+    grid = (triton.cdiv(N, BLOCK),)
+    kernel[grid](x, N, seed)
+    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
+
+# test normal PRNG
+@pytest.mark.parametrize('size, seed',
+                         [(size, seed) for size in [1000000]\
+                          for seed in [0, 42, 124, 54]]
+)
+def test_randn(size, seed, device='cuda'):
+    @triton.jit
+    def kernel(X, N, seed):
+        offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.randn(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+    # triton result
+    x = torch.empty(size, dtype=torch.float32, device=device)
+    N = x.numel()
+    grid = (triton.cdiv(N, BLOCK),)
+    kernel[grid](x, N, seed)
+    assert abs(x.mean()) < 1e-2
+    assert abs(x.std() - 1) < 1e-2
diff --git a/python/test/test_blocksparse.py b/python/test/operators/test_blocksparse.py
similarity index 100%
rename from python/test/test_blocksparse.py
rename to python/test/operators/test_blocksparse.py
diff --git a/python/test/test_cross_entropy.py b/python/test/operators/test_cross_entropy.py
similarity index 100%
rename from python/test/test_cross_entropy.py
rename to python/test/operators/test_cross_entropy.py
diff --git a/python/test/test_matmul.py b/python/test/operators/test_matmul.py
similarity index 100%
rename from python/test/test_matmul.py
rename to python/test/operators/test_matmul.py
diff --git a/python/test/test_comm.py b/python/test/runtime/test_comm.py
similarity index 100%
rename from python/test/test_comm.py
rename to python/test/runtime/test_comm.py
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index 7694b9ec9..3f08b5133 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -9,4 +9,4 @@ from . import code_gen
 from . import testing
 from . import ops
 # version
-__version__ = '1.0.0'
\ No newline at end of file
+__version__ = '1.0.0'
diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py
index ab1daeb41..b96260c51 100644
--- a/python/triton/language/__init__.py
+++ b/python/triton/language/__init__.py
@@ -1,2 +1,4 @@
 from . import core
-from .core import *
\ No newline at end of file
+from . import random
+from .core import *
+from .random import *
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index ff243bc5e..22cd717e7 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -648,6 +648,7 @@ def cdiv(x, div):
     """
     return (x + div - 1) // div
 
+
 @triton.jit
 def minimum(x, y):
     """
diff --git a/python/triton/language/random.py b/python/triton/language/random.py
new file mode 100644
index 000000000..913073679
--- /dev/null
+++ b/python/triton/language/random.py
@@ -0,0 +1,208 @@
+import triton
+import triton.language as tl
+
+
+# Notes
+# 1. triton doesn't support uint32, so we use int32 instead and benefit from the fact that two's complement operations are equivalent to uint operations.
+# 2. multiply_low_high is currently inefficient.
+# 3. Even though Philox sampling technically outputs ints, in many places we pretend the values are actually uints, e.g. in uint32_to_uniform_float.
+
+
+@triton.jit
+def PHILOX_KEY_A():
+    # 0x9E3779B9
+    return -1640531527
+
+
+@triton.jit
+def PHILOX_KEY_B():
+    # 0xBB67AE85
+    return -1150833019
+
+
+@triton.jit
+def PHILOX_ROUND_A():
+    # 0xD2511F53
+    return -766435501
+
+
+@triton.jit
+def PHILOX_ROUND_B():
+    # 0xCD9E8D57
+    return -845247145
+
+
+@triton.jit
+def hacky_to_uint64(x):
+    return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
+
+
+@triton.jit
+def multiply_low_high(a, b):
+    return (
+        a * b,
+        ((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
+    )
+
+
+@triton.jit
+def single_round(c0, c1, c2, c3, k0, k1):
+    A = PHILOX_ROUND_A()
+    B = PHILOX_ROUND_B()
+    lo0, hi0 = multiply_low_high(A, c0)
+    lo1, hi1 = multiply_low_high(B, c2)
+
+    return (
+        hi1 ^ c1 ^ k0,
+        lo1,
+        hi0 ^ c3 ^ k1,
+        lo0,
+    )
+
+
+@triton.jit
+def raise_key(k0, k1):
+    return (
+        k0 + PHILOX_KEY_A(),
+        k1 + PHILOX_KEY_B(),
+    )
+
+
+@triton.jit
+def philox_f(c0, c1, c2, c3, k0, k1):
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    k0, k1 = raise_key(k0, k1)
+    c0, c1, c2, c3 = single_round(c0, c1, c2, c3, k0, k1)
+    return c0, c1, c2, c3
+
+
+@triton.jit
+def uint32_to_uniform_float(x):
+    """
+    Numerically stable function to convert a random integer into a random float uniformly sampled in [0, 1).
+    This was originally designed for uint32, but it works with int32 too as long as the int32 values uniformly
+    cover all the possible values they can take.
+    """
+    mantissa = x & 0x7fffff
+    exp = 127
+    res = mantissa | (exp << 23)
+    return res.to(tl.float32, bitcast=True) - 1.0
+
+
+@triton.jit
+def pair_uniform_to_normal(u1, u2):
+    """Box-Muller transform"""
+    u1 = tl.maximum(1.0e-7, u1)
+    th = 6.283185307179586 * u2
+    r = tl.sqrt(-2.0 * tl.log(u1))
+    return r * tl.cos(th), r * tl.sin(th)
+
+
+@triton.jit
+def randint4x(seed, offset):
+    """
+    Given a :code:`seed` scalar and an :code:`offset` block, returns four
+    blocks of random :code:`int32`.
+
+    This is the maximally efficient entry point
+    to Triton's Philox pseudo-random number generator.
+
+    :param seed: The seed for generating random numbers.
+    :param offset: The offsets to generate random numbers for.
+    """
+    z = 0
+    return philox_f(offset, z, z, z, seed, z)
+
+
+@triton.jit
+def randint(seed, offset):
+    """
+    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
+    block of random :code:`int32`.
+
+    If you need multiple streams of random numbers,
+    using `randint4x` is likely to be faster than calling `randint` 4 times.
+
+    :param seed: The seed for generating random numbers.
+    :param offset: The offsets to generate random numbers for.
+    """
+    ret, _, _, _ = randint4x(seed, offset)
+    return ret
+
+
+@triton.jit
+def rand(seed, offset):
+    """
+    Given a :code:`seed` scalar and an :code:`offset` block,
+    returns a block of random :code:`float32` in :math:`U(0, 1)`.
+
+    :param seed: The seed for generating random numbers.
+    :param offset: The offsets to generate random numbers for.
+    """
+    source = randint(seed, offset)
+    return uint32_to_uniform_float(source)
+
+
+@triton.jit
+def randn(seed, offset):
+    """
+    Given a :code:`seed` scalar and an :code:`offset` block,
+    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`.
+
+    :param seed: The seed for generating random numbers.
+    :param offset: The offsets to generate random numbers for.
+    """
+    i1, i2, _, _ = randint4x(seed, offset)
+    u1 = uint32_to_uniform_float(i1)
+    u2 = uint32_to_uniform_float(i2)
+    n1, _ = pair_uniform_to_normal(u1, u2)
+    return n1
+
+
+@triton.jit
+def rand4x(seed, offsets):
+    """
+    Given a :code:`seed` scalar and an :code:`offsets` block,
+    returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`.
+
+    :param seed: The seed for generating random numbers.
+    :param offsets: The offsets to generate random numbers for.
+    """
+    i1, i2, i3, i4 = randint4x(seed, offsets)
+    u1 = uint32_to_uniform_float(i1)
+    u2 = uint32_to_uniform_float(i2)
+    u3 = uint32_to_uniform_float(i3)
+    u4 = uint32_to_uniform_float(i4)
+    return u1, u2, u3, u4
+
+
+@triton.jit
+def randn4x(seed, offset):
+    """
+    Given a :code:`seed` scalar and an :code:`offset` block,
+    returns 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`.
+
+    :param seed: The seed for generating random numbers.
+    :param offset: The offsets to generate random numbers for.
+    """
+    u1, u2, u3, u4 = rand4x(seed, offset)
+    n1, n2 = pair_uniform_to_normal(u1, u2)
+    n3, n4 = pair_uniform_to_normal(u3, u4)
+    return n1, n2, n3, n4
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index 4baa951c1..e0847ae86 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -43,7 +43,7 @@ def add_kernel(
     y = tl.load(y_ptr + offsets, mask=mask)
     output = x + y
     # Write x + y back to DRAM
-    tl.store(output_ptr + offsets, output)
+    tl.store(output_ptr + offsets, output, mask=mask)
 
 
 # %%
diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py
new file mode 100644
index 000000000..d988746a7
--- /dev/null
+++ b/python/tutorials/04-low-memory-dropout.py
@@ -0,0 +1,164 @@
+"""
+Low-Memory Dropout
+==================
+
+In this tutorial, you will write a memory-efficient implementation of dropout whose state
+will be composed of a single int32 seed. This differs from more traditional implementations of dropout,
+whose state is generally composed of a bit mask tensor of the same shape as the input.
+
+You will learn about:
+
+- The limitations of naive implementations of Dropout with PyTorch
+- Parallel pseudo-random number generation in Triton
+"""
+
+# %%
+# Baseline
+# -------------
+# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance
+# of deep neural networks in low-data regimes (i.e. regularization).
+#
+# It takes a vector as input and produces an output vector of the same shape. Each scalar in the
+# output has a probability :math:`p` of being changed to zero, and otherwise it is copied from the input.
+# This forces the network to perform well even when only :math:`1 - p` scalars from the input are available.
+#
+# At evaluation time we want to use the full power of the network, so we set :math:`p=0`. Naively this would
+# increase the norm of the output (which can be a bad thing, e.g. it can lead to an artificial decrease
+# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which
+# keeps the norm consistent regardless of the dropout probability.
+#
+# Let's first take a look at the baseline implementation.
+
+
+import tabulate
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _dropout(
+    x_ptr,  # pointer to the input
+    x_keep_ptr,  # pointer to a mask of 0s and 1s
+    output_ptr,  # pointer to the output
+    n_elements,  # number of elements in the `x` tensor
+    p,  # probability that an element of `x` is changed to zero
+    **meta,
+):
+    BLOCK_SIZE = meta['BLOCK_SIZE']
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    # Load data
+    x = tl.load(x_ptr + offsets, mask=mask)
+    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
+    # The line below is the crucial part, described in the paragraph above!
+    output = tl.where(x_keep, x / (1 - p), 0.0)
+    # Write-back output
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+def dropout(x, x_keep, p):
+    output = torch.empty_like(x)
+    assert x.is_contiguous()
+    n_elements = x.numel()
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
+    return output
+
+# Input tensor
+x = torch.randn(size=(10,)).cuda()
+# Dropout mask
+p = 0.5
+x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
+#
+output = dropout(x, x_keep=x_keep, p=p)
+print(tabulate.tabulate([
+    ["input"] + x.tolist(),
+    ["keep mask"] + x_keep.tolist(),
+    ["output"] + output.tolist()
+]))
+
+# %%
+# Seeded dropout
+# --------------
+# The baseline implementation of dropout above works fine, but it can be a bit awkward to deal with. Firstly,
+# we need to store the dropout mask for backpropagation. Secondly, dropout state management can get
+# very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in
+# https://pytorch.org/docs/1.9.0/checkpoint.html). In this tutorial we'll describe an alternative implementation
+# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management
+# of persisting randomness across multiple invocations of the kernel.
+#
+# Pseudorandom number generation in Triton is simple! In this tutorial we will use the
+# :code:`triton.language.rand` function, which generates a block of uniformly distributed :code:`float32`
+# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides
+# other :ref:`random number generation strategies <Random Number Generation>`.
+#
+# .. note::
+#    Triton's implementation of PRNG is based on the Philox algorithm (described in [SALMON2011]_).
+#
+# Let's put it all together.
+
+@triton.jit
+def _seeded_dropout(
+    x_ptr,
+    output_ptr,
+    n_elements,
+    p,
+    seed,
+    **meta,
+):
+    # compute memory offsets of elements handled by this instance
+    BLOCK_SIZE = meta['BLOCK_SIZE']
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # load data from x
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    # randomly prune it
+    random = tl.rand(seed, offsets)
+    x_keep = random > p
+    # write-back
+    output = tl.where(x_keep, x / (1 - p), 0.0)
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+def seeded_dropout(x, p, seed):
+    output = torch.empty_like(x)
+    assert x.is_contiguous()
+    n_elements = x.numel()
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
+    return output
+
+
+x = torch.randn(size=(10,)).cuda()
+# Compare this to the baseline - dropout mask is never instantiated!
+output = seeded_dropout(x, p=0.5, seed=123)
+output2 = seeded_dropout(x, p=0.5, seed=123)
+output3 = seeded_dropout(x, p=0.5, seed=512)
+
+print(tabulate.tabulate([
+    ["input"] + x.tolist(),
+    ["output (seed = 123)"] + output.tolist(),
+    ["output (seed = 123)"] + output2.tolist(),
+    ["output (seed = 512)"] + output3.tolist()
+]))
+
+# %%
+# Et voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the same!
+# If you'd like to explore further applications of pseudorandomness in GPU programming, we encourage you
+# to explore the `triton/language/random` module!
+
+# %%
+# Exercises
+# -------------
+# 1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
+# 2. Add support for striding.
+# 3. (challenge) Implement a kernel for the sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed.
+
+# %%
+# References
+# --------------
+#
+# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011
+# .. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014
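As a quick illustration of the API this patch exposes, here is a minimal, self-contained sketch of calling `tl.rand` from a standalone kernel, modeled on the `test_rand` test above. The kernel name `_uniform_fill`, the wrapper, and the block size are illustrative choices and not part of the patch itself:

import torch
import triton
import triton.language as tl

BLOCK = 1024

@triton.jit
def _uniform_fill(X, N, seed):
    # each program instance fills one BLOCK-wide slice with float32 values in [0, 1)
    offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    rand = tl.rand(seed, offset)  # seeded, offset-indexed Philox stream
    tl.store(X + offset, rand, mask=offset < N)

def uniform_fill(n, seed, device='cuda'):
    x = torch.empty(n, dtype=torch.float32, device=device)
    grid = (triton.cdiv(n, BLOCK),)
    _uniform_fill[grid](x, x.numel(), seed)
    return x

# The PRNG state is just (seed, offset): the same seed reproduces the same block of numbers,
# which is exactly the property the seeded-dropout tutorial relies on.
assert torch.equal(uniform_fill(4096, seed=42), uniform_fill(4096, seed=42))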