diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 695dfd1e3..d4ba42733 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -25,7 +25,7 @@ jobs: run: | alias python='python3' cd python - pip3 install -e . + pip3 install -e '.[tutorials]' - name: Build docs run: | @@ -39,4 +39,4 @@ jobs: eval `ssh-agent -s` DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }} git remote set-url origin git@github.com:openai/triton.git - git push \ No newline at end of file + git push diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 987b346a3..d99e95dc7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -28,7 +28,7 @@ jobs: run: | alias python='python3' cd python - pip3 install -e . + pip3 install -e '.[tests]' - name: Unit tests run: | @@ -44,4 +44,3 @@ jobs: pytest -vs . sudo nvidia-smi -i 0 -rgc sudo nvidia-smi -i 0 -rmc - diff --git a/docs/getting-started/installation.rst b/docs/getting-started/installation.rst index db6b6261b..20c4628bc 100644 --- a/docs/getting-started/installation.rst +++ b/docs/getting-started/installation.rst @@ -44,7 +44,7 @@ You can then test your installation by running the unit tests: .. 
code-block:: bash - pip install -r requirements-test.txt + pip install -e '.[tests]' pytest -vs test/unit/ and the benchmarks diff --git a/python/requirements-test.txt b/python/requirements-test.txt deleted file mode 100644 index 84893a889..000000000 --- a/python/requirements-test.txt +++ /dev/null @@ -1,3 +0,0 @@ -numpy -pytest -scipy >= 1.7.1 diff --git a/python/setup.py b/python/setup.py index 28194f41e..1171ad0a8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -126,7 +126,11 @@ setup( description="A language and compiler for custom Deep Learning operations", long_description="", packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"], - install_requires=["cmake", "torch", "filelock"], + install_requires=[ + "cmake", + "filelock", + "torch", + ], package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], @@ -142,4 +146,16 @@ setup( "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.6", ], + extras_require={ + "tests": [ + "numpy", + "pytest", + "scipy>=1.7.1", + ], + "tutorials": [ + "matplotlib", + "pandas", + "tabulate", + ], + }, ) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 012ff65d7..f6e7ec237 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -54,6 +54,7 @@ matmul_data = { @pytest.mark.parametrize('M, N, K', matmul_data.keys()) def test_matmul(M, N, K): + torch.manual_seed(0) ref_gpu_util = matmul_data[(M, N, K)]['v100'] cur_sm_clock = nvsmi(['clocks.current.sm'])[0] ref_sm_clock = 1350 @@ -99,6 +100,7 @@ elementwise_data = { @pytest.mark.parametrize('N', elementwise_data.keys()) def test_elementwise(N): + torch.manual_seed(0) ref_gpu_util = elementwise_data[N]['v100'] cur_mem_clock = nvsmi(['clocks.current.memory'])[0] ref_mem_clock = 877 diff --git 
a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 30e507b0d..e5559ca7f 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -133,7 +133,7 @@ torch.manual_seed(0) x = torch.randn(1823, 781, device='cuda') y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) -print(torch.allclose(y_triton, y_torch)) +assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) # %% # As expected, the results are identical. diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 240583df2..ddfe9c0bc 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -237,9 +237,9 @@ def matmul_kernel( a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # you can fuse arbitrary activation functions here - # while the accumulator is still in FP32 ! - if meta['ACTIVATION']: - accumulator = meta['ACTIVATION'](accumulator) + # while the accumulator is still in FP32! 
+ if ACTIVATION: + accumulator = ACTIVATION(accumulator) c = accumulator.to(tl.float16) # ----------------------------------------------------------- diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index 5c4f53435..cf172537a 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -42,9 +42,8 @@ def _dropout( output_ptr, # pointer to the output n_elements, # number of elements in the `x` tensor p, # probability that an element of `x` is changed to zero - **meta, + BLOCK_SIZE: tl.constexpr, ): - BLOCK_SIZE = meta['BLOCK_SIZE'] pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -108,10 +107,9 @@ def _seeded_dropout( n_elements, p, seed, - **meta, + BLOCK_SIZE: tl.constexpr, ): # compute memory offsets of elements handled by this instance - BLOCK_SIZE = meta['BLOCK_SIZE'] pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 82231e15c..1cefc60b9 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -8,11 +8,19 @@ import torch import triton import triton.language as tl +try: + # This is https://github.com/NVIDIA/apex, NOT the apex on PyPI, so it + # should not be added to extras_require in setup.py. 
+ import apex + HAS_APEX = True +except ModuleNotFoundError: + HAS_APEX = False + # Forward Pass @triton.jit -def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] +def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, + BLOCK_SIZE: tl.constexpr): # position of elements processed by this program row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE) @@ -42,11 +50,8 @@ def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META): # Backward pass (DX + partial DW + partial DB) @triton.jit -def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, - stride, N, eps, - **META): - GROUP_SIZE_M = META['GROUP_SIZE_M'] - BLOCK_SIZE_N = META['BLOCK_SIZE_N'] +def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps, + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # position of elements processed by this program row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) @@ -102,15 +107,14 @@ def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, @triton.jit -def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta): +def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): pid = tl.program_id(0) - BLOCK_SIZE_M = meta['BLOCK_SIZE_M'] - BLOCK_SIZE_N = meta['BLOCK_SIZE_N'] cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): - rows = i + tl.arange(0, meta['BLOCK_SIZE_M']) + rows = i + tl.arange(0, BLOCK_SIZE_M) mask = (rows[:, None] < M) & (cols[None, :] < N) offs = rows[:, None] * N + cols[None, :] dw += tl.load(DW + offs, mask=mask, other=0.) 
@@ -216,8 +220,8 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): x_names=['N'], x_vals=[512 * i for i in range(2, 32)], line_arg='provider', - line_vals=['triton', 'torch', 'apex'], - line_names=['Triton', 'Torch', 'Apex'], + line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []), + line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', plot_name='layer-norm-backward', @@ -239,7 +243,6 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c if provider == 'torch': y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) if provider == 'apex': - import apex apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype) y_fwd = lambda: apex_layer_norm(x) # forward pass diff --git a/python/tutorials/README.rst b/python/tutorials/README.rst index 24c752842..a36a08bbe 100644 --- a/python/tutorials/README.rst +++ b/python/tutorials/README.rst @@ -1,4 +1,11 @@ Tutorials ================== -Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one. \ No newline at end of file +Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one. + +To install the dependencies for the tutorials: + +.. code-block:: bash + + cd triton + pip install -e './python[tutorials]'