[GH-PAGES] Updated website
@@ -15,7 +15,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- The reduction operators in Triton.\n"
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.\nYou will learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- Reduction operators in Triton.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -33,21 +33,21 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n\n\n# Compute the row-wise softmax of x\ndef naive_softmax(x):\n # read MN elements ; write M elements\n x_max = torch.max(x, axis=1)[0]\n # read 2MN elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(x)\n # read MN elements ; write M elements\n denominator = torch.sum(numerator, axis=1)\n # read 2MN elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 7MN elements ; wrote 3MN + 2M elements\n return ret"
|
||||
"import torch\n\n\n# Compute the row-wise softmax of x\n@torch.jit.script\ndef naive_softmax(x):\n # read MN elements ; write M elements\n x_max = x.max(dim=1)[0]\n # read 2MN elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(x)\n # read MN elements ; write M elements\n denominator = numerator.sum(dim=1)\n # read 2MN elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 7MN elements ; wrote 3MN + 2M elements\n return ret"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads X once and does all the necessary computations on-chip.\nThis solution would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nIn practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.\n\n"
|
||||
"When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer to have a custom \"fused\" kernel that only reads X once and does all the necessary computations on-chip.\nDoing so would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nThe `torch.jit.script` flags aims to perform this kind of \"kernel fusion\" automatically but, as we will see later, it is still far from ideal.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a power-of-two number of elements,\nso we need to internally \"pad\" tiles and guard the memory operations properly if we want to handle any possible input shapes:\n\n"
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.\nNote that one important limitation of Triton is that each block must have a power-of-two number of elements,\nso we need to internally \"pad\" each row and guard the memory operations properly if we want to handle any possible input shapes:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -58,7 +58,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):\n # row index\n m = tl.program_id(0)\n # col indices\n n = tl.arange(0, meta['BLOCK'])\n # the memory address of all the elements\n # that we want to load can be computed as follows\n X = X + m * stride_xm + n\n x = tl.load(X, mask=n < N, other=-float('inf'))\n # Substract maximum for numerical stability\n z = x - tl.max(x, axis=0)\n # Note that exponentials in Triton are fast\n # but approximate (i.e., think __expf in CUDA)\n num = tl.exp(z)\n denom = tl.sum(num, axis=0)\n y = num / denom\n # Write back to Y\n Y = Y + m * stride_ym + n\n tl.store(Y, y, mask=n < N)"
|
||||
"import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):\n # row index\n m = tl.program_id(0)\n # col indices\n # here BLOCK is the smallest power of two greater than `N`\n n = tl.arange(0, meta['BLOCK'])\n # the memory address of all the elements\n # that we want to load can be computed as follows\n X = X + m * stride_xm + n\n x = tl.load(X, mask=n < N, other=-float('inf'))\n # Substract maximum for numerical stability\n z = x - tl.max(x, axis=0)\n # Note that exponentials in Triton are fast\n # but approximate (i.e., think __expf in CUDA)\n num = tl.exp(z)\n denom = tl.sum(num, axis=0)\n y = num / denom\n # Write back to Y\n Y = Y + m * stride_ym + n\n tl.store(Y, y, mask=n < N)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -76,7 +76,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef softmax(x):\n M, N = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK = next_power_of_2(N)\n # Another trick we can use is to ask the compiler to parallelize each\n # row-normalization more aggressively -- i.e., with more warps -- vectors\n # that are longer\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself\n num_warps = 4\n if BLOCK >= 2048: num_warps = 8\n if BLOCK >= 4096: num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix\n _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)\n return y"
|
||||
"def next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef softmax(x):\n M, N = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK = next_power_of_2(N)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK >= 2048: num_warps = 8\n if BLOCK >= 4096: num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix\n _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)\n return y"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -126,14 +126,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'], # argument names to use as an x-axis for the plot\n x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg``\n line_names=[\"Torch\", \"Triton\", 'Naive'], # label name for the lines\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"softmax-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={'M': 4096} # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(M, N, provider):\n x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n if provider == 'torch':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n if provider == 'naive':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True)"
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'], # argument names to use as an x-axis for the plot\n x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``\n line_names=[\"Triton\", \"Torch (native)\", \"Torch (jit)\"], # label name for the lines\n styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles\n ylabel=\"GB/s\", # label name for the y-axis\n plot_name=\"softmax-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={'M': 4096} # values for function arguments not in `x_names` and `y_name`\n )\n)\ndef benchmark(M, N, provider):\n x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n if provider == 'torch-native':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n if provider == 'torch-jit':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In the above plot, we can see that:\n\n - Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.\n - Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.\n This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of data necessary.\n Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.\n"
|
||||
"In the above plot, we can see that:\n\n - Triton is 2-3x faster than the Torch JIT.\n - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.\n This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.\n Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -153,7 +153,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -3,9 +3,9 @@ Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
||||
|
||||
- The basic programming model used by Triton
|
||||
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
- The basic programming model of Triton
|
||||
- The `triton.jit` decorator, which is used to define Triton kernels.
|
||||
- The best practices for validating and benchmarking your custom ops against native reference implementations
|
||||
"""
|
||||
|
||||
# %%
|
||||
@@ -41,28 +41,28 @@ def _add(
|
||||
|
||||
|
||||
# %%
|
||||
# We can also declara a helper function that handles allocating the output vector
|
||||
# and enqueueing the kernel.
|
||||
# Let's also declare a helper function that to (1) allocate the output vector
|
||||
# and (2) enqueueing the above kernel.
|
||||
|
||||
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# NOTE:
|
||||
# - torch.tensor objects are implicitly converted to pointers to their first element.
|
||||
# - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously.
|
||||
# running asynchronously at this point.
|
||||
return z
|
||||
|
||||
|
||||
# %%
|
||||
# We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
|
||||
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
@@ -81,7 +81,7 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
|
||||
# Benchmark
|
||||
# -----------
|
||||
# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
|
||||
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
|
||||
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
|
||||
# for different problem sizes.
|
||||
|
||||
|
||||
@@ -112,4 +112,4 @@ def benchmark(size, provider):
|
||||
# %%
|
||||
# We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
|
||||
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
|
||||
benchmark.run(show_plots=True)
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
@@ -15,21 +15,21 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Matrix Multiplication\nIn this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.\nYou will specifically learn about:\n\n- Block-level matrix multiplications\n- Multi-dimensional pointer arithmetic\n- Program re-ordering for improved L2 cache hit rate \n- Automatic performance tuning\n"
|
||||
"\n# Matrix Multiplication\nIn this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.\nYou will specifically learn about:\n\n- Block-level matrix multiplications\n- Multi-dimensional pointer arithmetic\n- Program re-ordering for improved L2 cache hit rate \n- Automatic performance tuning\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Motivations\nMatrix multiplications are a key building block of most modern high-performance computing systems.\nThey are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called \"kernel libraries\" (e.g., cuBLAS).\nUnfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).\nFor this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.\n\nRoughly speaking, the kernel that we will write will implement the following blocked algorithm:\n\n .. code-block:: python\n\n # do in parallel\n for m in range(0, M, BLOCK_M):\n # do in parallel\n for n in range(0, N, BLOCK_N):\n acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)\n for k in range(0, K, BLOCK_K):\n a = A[m : m+BLOCK_M, k : k+BLOCK_K]\n b = B[k : k+BLOCK_K, n : n+BLOCK_N]\n acc += dot(a, b)\n C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;\n\nwhere each iteration of the doubly-nested for-loop corresponds to a Triton program instance.\n\n"
|
||||
"## Motivations\nMatrix multiplications are a key building block of most modern high-performance computing systems.\nThey are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called \"kernel libraries\" (e.g., cuBLAS).\nUnfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).\nIn this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.\n\nRoughly speaking, the kernel that we will write will implement the following blocked algorithm:\n\n .. code-block:: python\n\n # do in parallel\n for m in range(0, M, BLOCK_M):\n # do in parallel\n for n in range(0, N, BLOCK_N):\n acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)\n for k in range(0, K, BLOCK_K):\n a = A[m : m+BLOCK_M, k : k+BLOCK_K]\n b = B[k : k+BLOCK_K, n : n+BLOCK_N]\n acc += dot(a, b)\n C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;\n\nwhere each iteration of the doubly-nested for-loop corresponds to a Triton program instance.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\n\nThe above algorithm is actually fairly straightforward to implement in Triton.\nThe main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.\n\n### Pointer Arithmetics\n\nFor a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.\nTherefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:\n\n .. code-block:: python\n\n &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];\n &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];\n\nWhich means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:\n\n .. code-block:: python\n\n pid_m = triton.program_id(0)\n pid_n = triton.program_id(1)\n rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)\n rk = triton.arange(0, BLOCK_K)\n // pointer for A operand\n pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);\n // pointer for B operand\n pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);\n\nThese pointers can then be updated in the inner loop as:\n\n .. code-block:: python\n\n pa += BLOCK_K * stride_a_1;\n pb += BLOCK_K * stride_b_0;\n\n\n### L2 Cache Optimizations\n\nAs mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.\nHowever, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.\nThis means that a naive row-major ordering:\n\n .. code-block:: Python\n\n pid = triton.program_id(0);\n grid_m = (M + BLOCK_M - 1) // BLOCK_M;\n grid_n = (N + BLOCK_N - 1) // BLOCK_N;\n pid_m = pid / grid_n;\n pid_n = pid % grid_n;\n\nis unlikely to result in optimal performance.\n\nOne possible solution is to launch blocks in an order that promotes data reuse.\nThis can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:\n\n .. code-block:: python\n\n pid = triton.program_id(0);\n width = GROUP_M * grid_n;\n group_id = pid // width;\n # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M);\n pid_m = group_id * GROUP_M + (pid % group_size);\n pid_n = (pid % width) // (group_size);\n\nIn practice, this can improve the performance of our matrix multiplication kernel by >10\\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).\n\n\n"
|
||||
"## Compute Kernel\n\nThe above algorithm is, actually, fairly straightforward to implement in Triton.\nThe main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.\n\n### Pointer Arithmetics\n\nFor a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given by :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.\nTherefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :code:`B[k : k+BLOCK_K, n : n+BLOCK_N]` can be defined in pseudo-code as:\n\n .. code-block:: python\n\n &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);\n &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);\n\nWhich means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:\n\n .. code-block:: python\n\n pid_m = triton.program_id(0)\n pid_n = triton.program_id(1)\n rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)\n rk = triton.arange(0, BLOCK_K)\n // pointer for A operand\n pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);\n // pointer for B operand\n pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);\n\nAnd then updated in the inner loop as follows:\n\n .. code-block:: python\n\n pa += BLOCK_K * stride_a_1;\n pb += BLOCK_K * stride_b_0;\n\n\n### L2 Cache Optimizations\n\nAs mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.\nIt is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.\nAnd unfortunately, a simple row-major ordering\n\n .. code-block:: Python\n\n pid = triton.program_id(0);\n grid_m = (M + BLOCK_M - 1) // BLOCK_M;\n grid_n = (N + BLOCK_N - 1) // BLOCK_N;\n pid_m = pid / grid_n;\n pid_n = pid % grid_n;\n\nis just not going to cut it.\n\nOne possible solution is to launch blocks in an order that promotes data reuse.\nThis can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:\n\n .. code-block:: python\n\n pid = triton.program_id(0);\n width = GROUP_M * grid_n;\n group_id = pid // width;\n # we need to handle the case where M % (GROUP_M*BLOCK_M) != 0\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M);\n pid_m = group_id * GROUP_M + (pid % group_size);\n pid_n = (pid % width) // (group_size);\n\nIn practice, this can improve the performance of our matrix multiplication kernel by >10\\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).\n\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -47,14 +47,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport triton\nimport triton.language as tl\n\n# %\n# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:\n# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try\n# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs\n\n\n@triton.jit\ndef sigmoid(x):\n ret_true = 1 / (1 + tl.exp(-x))\n ret_false = tl.exp(x) / (1 + tl.exp(x))\n return tl.where(x >= 0, ret_true, ret_false)\n\n\n@triton.jit\ndef swish(x):\n return x * sigmoid(x)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n)\n# %\n# We can now define our kernel as normal, using all the techniques presented above\n@triton.jit\ndef _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):\n # extract meta-parameters\n BLOCK_M = META['BLOCK_M']\n BLOCK_N = META['BLOCK_N']\n BLOCK_K = META['BLOCK_K']\n GROUP_M = 8\n # matrix multiplication\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(K, 0, -BLOCK_K):\n a = tl.load(A)\n b = tl.load(B)\n acc += tl.dot(a, b)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n # triton can accept arbitrary activation function\n # via metaparameters!\n if META['ACTIVATION']:\n acc = META['ACTIVATION'](acc)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm[:, None] < M) & (rn[None, :] < N)\n tl.store(C, acc, mask=mask)"
|
||||
"import torch\nimport triton\nimport triton.language as tl\n\n# %\n# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:\n# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try\n# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\\\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\\\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\\\n triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\\\n triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\\\n triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\n #triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n)\n# %\n# We can now define our kernel as normal, using all the techniques presented above\n@triton.jit\ndef _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):\n # extract meta-parameters\n BLOCK_M = META['BLOCK_M']\n BLOCK_N = META['BLOCK_N']\n BLOCK_K = META['BLOCK_K']\n GROUP_M = 8\n # matrix multiplication\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(K, 0, -BLOCK_K):\n a = tl.load(A)\n b = tl.load(B)\n acc += tl.dot(a, b)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n # triton can accept arbitrary activation function\n # via metaparameters!\n if META['ACTIVATION']:\n acc = META['ACTIVATION'](acc)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm[:, None] < M) & (rn[None, :] < N)\n tl.store(C, acc, mask=mask)\n\n\n# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01*x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also create a convenience wrapper function that only takes two input tensors\nand (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel\n\n"
|
||||
"We can now create a convenience wrapper function that only takes two input tensors\nand (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -65,14 +65,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def matmul(a, b, activation=None):\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n assert a.is_contiguous(), \"matrix A must be contiguous\"\n assert b.is_contiguous(), \"matrix B must be contiguous\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n # launch kernel\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )\n _matmul[grid](\n a, b, c, M, N, K, \\\n a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\\\n ACTIVATION = activation\n )\n # return output\n return c"
|
||||
"def matmul(a, b, activation=None):\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n assert a.is_contiguous(), \"matrix A must be contiguous\"\n assert b.is_contiguous(), \"matrix B must be contiguous\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n # launch kernel\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )\n pgm = _matmul[grid](\n a, b, c, M, N, K, \\\n a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\\\n ACTIVATION = activation\n )\n # done; return the output tensor\n return c"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Unit Test\n\nWe can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)\n\n"
|
||||
"## Unit Test\n\nWe can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -83,14 +83,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#torch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nc_0 = matmul(a, b, activation=swish)\nc_1 = torch.nn.SiLU()(torch.matmul(a, b))\nprint(c_0)\nprint(c_1)\nprint(triton.testing.allclose(c_0, c_1))"
|
||||
"torch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nc_0 = matmul(a, b, activation=None)\nc_1 = torch.matmul(a, b)\nprint(c_0)\nprint(c_1)\nprint(triton.testing.allclose(c_0, c_1))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmark\n\n### Square Matrix Performance\nWe can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#\n\n"
|
||||
"## Benchmark\n\n### Square Matrix Performance\nWe can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -101,7 +101,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot\n x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['cublas', 'triton'], # possible values for `line_arg``\n line_names=[\"cuBLAS\", \"Triton\"], # label name for the lines\n ylabel=\"TFLOPS\", # label name for the y-axis\n plot_name=\"matmul-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={}\n )\n)\ndef benchmark(M, N, K, provider):\n silu = torch.nn.SiLU()\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n if provider == 'cublas':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))\n perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)\n return perf(ms), perf(max_ms), perf(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
|
||||
"@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot\n x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`\n line_arg='provider', # argument name whose value corresponds to a different line in the plot\n line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg``\n line_names=[\"cuBLAS\", \"cuBLAS (+ torch.nn.LeakyReLU)\", \"Triton\", \"Triton (+ LeakyReLU)\"], # label name for the lines\n styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles\n ylabel=\"TFLOPS\", # label name for the y-axis\n plot_name=\"matmul-performance\", # name for the plot. Used also as a file name for saving the plot.\n args={}\n )\n)\ndef benchmark(M, N, K, provider):\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n if provider == 'cublas':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))\n if provider == 'triton':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))\n if provider == 'cublas + relu':\n torch_relu = torch.nn.ReLU(inplace=True)\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))\n if provider == 'triton + relu':\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))\n perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)\n return perf(ms), perf(max_ms), perf(min_ms)\n\n\nbenchmark.run(show_plots=True, print_data=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -121,7 +121,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
|
||||
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
|
||||
You will specifically learn about:
|
||||
|
||||
- Block-level matrix multiplications
|
||||
@@ -14,9 +14,9 @@ You will specifically learn about:
|
||||
# Motivations
|
||||
# -------------
|
||||
# Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
# They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
# For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
|
||||
# They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
# Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
# In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
|
||||
#
|
||||
# Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
|
||||
#
|
||||
@@ -39,8 +39,8 @@ You will specifically learn about:
|
||||
# Compute Kernel
|
||||
# ----------------
|
||||
#
|
||||
# The above algorithm is actually fairly straightforward to implement in Triton.
|
||||
# The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
|
||||
# The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
# The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.
|
||||
#
|
||||
# Pointer Arithmetics
|
||||
# ~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -50,10 +50,10 @@ You will specifically learn about:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
|
||||
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
|
||||
# &A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
|
||||
# &B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
|
||||
#
|
||||
# Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
|
||||
# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
@@ -67,7 +67,7 @@ You will specifically learn about:
|
||||
# // pointer for B operand
|
||||
# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
|
||||
#
|
||||
# These pointers can then be updated in the inner loop as:
|
||||
# And then updated in the inner loop as follows:
|
||||
#
|
||||
# .. code-block:: python
|
||||
#
|
||||
@@ -79,8 +79,8 @@ You will specifically learn about:
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
|
||||
# However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
|
||||
# This means that a naive row-major ordering:
|
||||
# It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||||
# And unfortunately, a simple row-major ordering
|
||||
#
|
||||
# .. code-block:: Python
|
||||
#
|
||||
@@ -90,7 +90,7 @@ You will specifically learn about:
|
||||
# pid_m = pid / grid_n;
|
||||
# pid_n = pid % grid_n;
|
||||
#
|
||||
# is unlikely to result in optimal performance.
|
||||
# is just not going to cut it.
|
||||
#
|
||||
# One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||
@@ -122,23 +122,19 @@ import triton.language as tl
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
|
||||
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sigmoid(x):
|
||||
ret_true = 1 / (1 + tl.exp(-x))
|
||||
ret_false = tl.exp(x) / (1 + tl.exp(x))
|
||||
return tl.where(x >= 0, ret_true, ret_false)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def swish(x):
|
||||
return x * sigmoid(x)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\
|
||||
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
|
||||
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -186,10 +182,14 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
|
||||
tl.store(C, acc, mask=mask)
|
||||
|
||||
|
||||
# %%
|
||||
# We can also create a convenience wrapper function that only takes two input tensors
|
||||
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
|
||||
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
||||
@triton.jit
|
||||
def leaky_relu(x):
|
||||
return tl.where(x >= 0, x, 0.01*x)
|
||||
|
||||
# %%
|
||||
# We can now create a convenience wrapper function that only takes two input tensors
|
||||
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
@@ -202,12 +202,12 @@ def matmul(a, b, activation=None):
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
_matmul[grid](
|
||||
pgm = _matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
)
|
||||
# return output
|
||||
# done; return the output tensor
|
||||
return c
|
||||
|
||||
|
||||
@@ -215,13 +215,13 @@ def matmul(a, b, activation=None):
|
||||
# Unit Test
|
||||
# -----------
|
||||
#
|
||||
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
||||
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
|
||||
|
||||
#torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=swish)
|
||||
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
|
||||
c_0 = matmul(a, b, activation=None)
|
||||
c_1 = torch.matmul(a, b)
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
@@ -232,29 +232,34 @@ print(triton.testing.allclose(c_0, c_1))
|
||||
#
|
||||
# Square Matrix Performance
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
|
||||
# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
||||
x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['cublas', 'triton'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "Triton"], # label name for the lines
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], # label name for the lines
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
silu = torch.nn.SiLU()
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
|
||||
You will learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- The reduction operators in Triton.
|
||||
- Reduction operators in Triton.
|
||||
"""
|
||||
|
||||
# %%
|
||||
@@ -17,15 +18,16 @@ import torch
|
||||
|
||||
|
||||
# Compute the row-wise softmax of x
|
||||
@torch.jit.script
|
||||
def naive_softmax(x):
|
||||
# read MN elements ; write M elements
|
||||
x_max = torch.max(x, axis=1)[0]
|
||||
x_max = x.max(dim=1)[0]
|
||||
# read 2MN elements ; write MN elements
|
||||
z = x - x_max[:, None]
|
||||
# read MN elements ; write MN elements
|
||||
numerator = torch.exp(x)
|
||||
# read MN elements ; write M elements
|
||||
denominator = torch.sum(numerator, axis=1)
|
||||
denominator = numerator.sum(dim=1)
|
||||
# read 2MN elements ; write MN elements
|
||||
ret = numerator / denominator[:, None]
|
||||
# in total: read 7MN elements ; wrote 3MN + 2M elements
|
||||
@@ -35,15 +37,15 @@ def naive_softmax(x):
|
||||
# %%
|
||||
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
|
||||
# This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
# In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
|
||||
# Doing so would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# ----------------
|
||||
# Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
# Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
# so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
# so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -54,6 +56,7 @@ def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
|
||||
# row index
|
||||
m = tl.program_id(0)
|
||||
# col indices
|
||||
# here BLOCK is the smallest power of two greater than `N`
|
||||
n = tl.arange(0, meta['BLOCK'])
|
||||
# the memory address of all the elements
|
||||
# that we want to load can be computed as follows
|
||||
@@ -90,11 +93,10 @@ def softmax(x):
|
||||
M, N = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
BLOCK = next_power_of_2(N)
|
||||
# Another trick we can use is to ask the compiler to parallelize each
|
||||
# row-normalization more aggressively -- i.e., with more warps -- vectors
|
||||
# that are longer
|
||||
# Another trick we can use is to ask the compiler to use more threads per row by
|
||||
# increasing the number of warps (`num_warps`) over which each row is distributed.
|
||||
# You will see in the next tutorial how to auto-tune this value in a more natural
|
||||
# way so you don't have to come up with manual heuristics yourself
|
||||
# way so you don't have to come up with manual heuristics yourself.
|
||||
num_warps = 4
|
||||
if BLOCK >= 2048: num_warps = 8
|
||||
if BLOCK >= 4096: num_warps = 16
|
||||
@@ -132,10 +134,11 @@ print(torch.allclose(y_tri, y_ref))
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
|
||||
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg``
|
||||
line_names=["Torch", "Triton", 'Naive'], # label name for the lines
|
||||
line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``
|
||||
line_names=["Triton", "Torch (native)", "Torch (jit)"], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
|
||||
@@ -143,22 +146,22 @@ print(torch.allclose(y_tri, y_ref))
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
if provider == 'torch':
|
||||
if provider == 'torch-native':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
|
||||
if provider == 'naive':
|
||||
if provider == 'torch-jit':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
|
||||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
|
||||
# %%
|
||||
# In the above plot, we can see that:
|
||||
#
|
||||
# - Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.
|
||||
# - Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
|
||||
# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of data necessary.
|
||||
# - Triton is 2-3x faster than the Torch JIT.
|
||||
# - Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
|
||||
# This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
|
||||
# Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
|
||||
@@ -15,7 +15,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic programming model used by Triton\n- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.\n- The best practices for validating and benchmarking custom ops against native reference implementations\n"
|
||||
"\n# Vector Addition\nIn this tutorial, you will write a simple vector addition using Triton and learn about:\n\n- The basic programming model of Triton\n- The `triton.jit` decorator, which is used to define Triton kernels.\n- The best practices for validating and benchmarking your custom ops against native reference implementations\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -40,7 +40,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also declara a helper function that handles allocating the output vector\nand enqueueing the kernel.\n\n"
|
||||
"Let's also declare a helper function that to (1) allocate the output vector\nand (2) enqueueing the above kernel.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -51,14 +51,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def add(x, y):\n z = torch.empty_like(x)\n N = z.shape[0]\n # The SPMD launch grid denotes the number of kernel instances that should execute in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )\n # NOTE:\n # - torch.tensor objects are implicitly converted to pointers to their first element.\n # - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n _add[grid](x, y, z, N, BLOCK=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously.\n return z"
|
||||
"def add(x, y):\n z = torch.empty_like(x)\n N = z.shape[0]\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )\n # NOTE:\n # - each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel\n # - don't forget to pass meta-parameters as keywords arguments\n _add[grid](x, y, z, N, BLOCK=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return z"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:\n\n"
|
||||
"We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -83,7 +83,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmark\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.\nfor different problem sizes.\n\n"
|
||||
"## Benchmark\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.\nTo make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops\nfor different problem sizes.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -112,7 +112,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"benchmark.run(show_plots=True)"
|
||||
"benchmark.run(print_data=True, show_plots=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -132,7 +132,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
|
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 23 KiB |
|
Before Width: | Height: | Size: 17 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 34 KiB After Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 24 KiB |
|
Before Width: | Height: | Size: 34 KiB After Width: | Height: | Size: 55 KiB |
|
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 32 KiB |
@@ -22,9 +22,9 @@ Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple vector addition using Triton and learn about:
|
||||
|
||||
- The basic programming model used by Triton
|
||||
- The `triton.jit` decorator, which constitutes the main entry point for writing Triton kernels.
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
- The basic programming model of Triton
|
||||
- The `triton.jit` decorator, which is used to define Triton kernels.
|
||||
- The best practices for validating and benchmarking your custom ops against native reference implementations
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-14
|
||||
|
||||
@@ -73,8 +73,8 @@ Compute Kernel
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-46
|
||||
|
||||
We can also declara a helper function that handles allocating the output vector
|
||||
and enqueueing the kernel.
|
||||
Let's also declare a helper function that to (1) allocate the output vector
|
||||
and (2) enqueueing the above kernel.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 46-64
|
||||
|
||||
@@ -85,16 +85,16 @@ and enqueueing the kernel.
|
||||
def add(x, y):
|
||||
z = torch.empty_like(x)
|
||||
N = z.shape[0]
|
||||
# The SPMD launch grid denotes the number of kernel instances that should execute in parallel.
|
||||
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
|
||||
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
|
||||
# NOTE:
|
||||
# - torch.tensor objects are implicitly converted to pointers to their first element.
|
||||
# - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel
|
||||
# - each torch.tensor object is implicitly converted into a pointer to its first element.
|
||||
# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
|
||||
# - don't forget to pass meta-parameters as keywords arguments
|
||||
_add[grid](x, y, z, N, BLOCK=1024)
|
||||
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
|
||||
# running asynchronously.
|
||||
# running asynchronously at this point.
|
||||
return z
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ and enqueueing the kernel.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-66
|
||||
|
||||
We can now use the above function to compute the sum of two `torch.tensor` objects and test our results:
|
||||
We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 66-77
|
||||
|
||||
@@ -150,7 +150,7 @@ Seems like we're good to go!
|
||||
Benchmark
|
||||
-----------
|
||||
We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
|
||||
for different problem sizes.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 86-112
|
||||
@@ -199,7 +199,7 @@ We can now run the decorated function above. Pass `show_plots=True` to see the p
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_01-vector-add_001.png
|
||||
@@ -207,13 +207,38 @@ We can now run the decorated function above. Pass `show_plots=True` to see the p
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
vector-add-performance:
|
||||
size Torch Triton
|
||||
0 4096.0 9.600000 9.600000
|
||||
1 8192.0 19.200000 19.200000
|
||||
2 16384.0 38.400001 38.400001
|
||||
3 32768.0 76.800002 63.999998
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 722.823517
|
||||
10 4194304.0 780.190482 780.190482
|
||||
11 8388608.0 812.429770 812.429770
|
||||
12 16777216.0 833.084721 833.084721
|
||||
13 33554432.0 843.811163 843.811163
|
||||
14 67108864.0 848.362445 849.278610
|
||||
15 134217728.0 850.656574 851.577704
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 7.682 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 11.005 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
||||
@@ -20,19 +20,20 @@
|
||||
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:
|
||||
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
|
||||
You will learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- The reduction operators in Triton.
|
||||
- Reduction operators in Triton.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 11-15
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-16
|
||||
|
||||
Motivations
|
||||
------------
|
||||
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 15-35
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 16-37
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -41,15 +42,16 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
|
||||
|
||||
# Compute the row-wise softmax of x
|
||||
@torch.jit.script
|
||||
def naive_softmax(x):
|
||||
# read MN elements ; write M elements
|
||||
x_max = torch.max(x, axis=1)[0]
|
||||
x_max = x.max(dim=1)[0]
|
||||
# read 2MN elements ; write MN elements
|
||||
z = x - x_max[:, None]
|
||||
# read MN elements ; write MN elements
|
||||
numerator = torch.exp(x)
|
||||
# read MN elements ; write M elements
|
||||
denominator = torch.sum(numerator, axis=1)
|
||||
denominator = numerator.sum(dim=1)
|
||||
# read 2MN elements ; write MN elements
|
||||
ret = numerator / denominator[:, None]
|
||||
# in total: read 7MN elements ; wrote 3MN + 2M elements
|
||||
@@ -63,22 +65,22 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 36-40
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 38-42
|
||||
|
||||
When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip.
|
||||
This solution would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.
|
||||
Doing so would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of ~5x (i.e., :math:`(10MN + 2M) / 2MN`).
|
||||
The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 42-47
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 44-49
|
||||
|
||||
Compute Kernel
|
||||
----------------
|
||||
Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 47-74
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 49-77
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -92,6 +94,7 @@ so we need to internally "pad" tiles and guard the memory operations properly if
|
||||
# row index
|
||||
m = tl.program_id(0)
|
||||
# col indices
|
||||
# here BLOCK is the smallest power of two greater than `N`
|
||||
n = tl.arange(0, meta['BLOCK'])
|
||||
# the memory address of all the elements
|
||||
# that we want to load can be computed as follows
|
||||
@@ -116,11 +119,11 @@ so we need to internally "pad" tiles and guard the memory operations properly if
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 75-76
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 78-79
|
||||
|
||||
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 76-108
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 79-110
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -141,11 +144,10 @@ We can create a helper function that enqueues the kernel and its (meta-)argument
|
||||
M, N = x.shape
|
||||
# The block size is the smallest power of two greater than the number of columns in `x`
|
||||
BLOCK = next_power_of_2(N)
|
||||
# Another trick we can use is to ask the compiler to parallelize each
|
||||
# row-normalization more aggressively -- i.e., with more warps -- vectors
|
||||
# that are longer
|
||||
# Another trick we can use is to ask the compiler to use more threads per row by
|
||||
# increasing the number of warps (`num_warps`) over which each row is distributed.
|
||||
# You will see in the next tutorial how to auto-tune this value in a more natural
|
||||
# way so you don't have to come up with manual heuristics yourself
|
||||
# way so you don't have to come up with manual heuristics yourself.
|
||||
num_warps = 4
|
||||
if BLOCK >= 2048: num_warps = 8
|
||||
if BLOCK >= 4096: num_warps = 16
|
||||
@@ -163,17 +165,17 @@ We can create a helper function that enqueues the kernel and its (meta-)argument
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 109-111
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 111-113
|
||||
|
||||
Unit Test
|
||||
----------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 113-115
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-117
|
||||
|
||||
We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
|
||||
This will allow us to verify that our padding mechanism works.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-122
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 117-124
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -199,18 +201,18 @@ This will allow us to verify that our padding mechanism works.
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 123-124
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 125-126
|
||||
|
||||
As expected, the results are identical.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 126-130
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 128-132
|
||||
|
||||
Benchmark
|
||||
-------------
|
||||
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 130-158
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 132-161
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -219,10 +221,11 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
|
||||
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg``
|
||||
line_names=["Torch", "Triton", 'Naive'], # label name for the lines
|
||||
line_vals=['triton', 'torch-native', 'torch-jit'], # possible values for `line_arg``
|
||||
line_names=["Triton", "Torch (native)", "Torch (jit)"], # label name for the lines
|
||||
styles=[('blue', '-'), ('green', '-'), ('green', '--')], # line styles
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
|
||||
@@ -230,17 +233,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
if provider == 'torch':
|
||||
if provider == 'torch-native':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
|
||||
if provider == 'naive':
|
||||
if provider == 'torch-jit':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
|
||||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
|
||||
|
||||
|
||||
@@ -250,22 +253,44 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
softmax-performance:
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 512.000001 546.133347 273.066674
|
||||
1 384.0 585.142862 585.142862 267.130429
|
||||
2 512.0 630.153853 606.814814 264.258068
|
||||
3 640.0 682.666684 640.000002 269.473696
|
||||
4 768.0 702.171410 664.216187 273.066663
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 405.755985 329.483481
|
||||
94 12288.0 812.429770 415.661740 329.602681
|
||||
95 12416.0 810.840807 412.149375 329.173158
|
||||
96 12544.0 810.925276 412.971190 329.292871
|
||||
97 12672.0 811.007961 412.097543 329.142870
|
||||
|
||||
[98 rows x 4 columns]
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 159-164
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 162-167
|
||||
|
||||
In the above plot, we can see that:
|
||||
|
||||
- Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.
|
||||
- Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
|
||||
This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of data necessary.
|
||||
- Triton is 2-3x faster than the Torch JIT.
|
||||
- Triton is even faster than :code:`torch.softmax`. My guess from looking at the source-code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.
|
||||
This means that -- when temporary data is too large to fit entirely in the GPU's cache -- it transfers almost twice the amount of memory necessary.
|
||||
Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 20.250 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 8.184 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
Matrix Multiplication
|
||||
======================
|
||||
In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
|
||||
In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
|
||||
You will specifically learn about:
|
||||
|
||||
- Block-level matrix multiplications
|
||||
@@ -33,9 +33,9 @@ You will specifically learn about:
|
||||
Motivations
|
||||
-------------
|
||||
Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.
|
||||
They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
|
||||
|
||||
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
|
||||
|
||||
@@ -59,8 +59,8 @@ where each iteration of the doubly-nested for-loop corresponds to a Triton progr
|
||||
Compute Kernel
|
||||
----------------
|
||||
|
||||
The above algorithm is actually fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of :code:`A` and :code:`B` that we need to read in the inner loop.
|
||||
The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks of :code:`A` and :code:`B` must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.
|
||||
|
||||
Pointer Arithmetics
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -70,10 +70,10 @@ Therefore, blocks of pointers for :code:`A[m : m+BLOCK_M, k:k+BLOCK_K]` and :cod
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
&A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :];
|
||||
&B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :];
|
||||
&A[m : m+BLOCK_M, k:k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
|
||||
&B[k : k+BLOCK_K, n:n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
|
||||
|
||||
Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks of A and B can be initialized in Triton as:
|
||||
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -87,7 +87,7 @@ Which means that, at initialization (i.e., :code:`k = 0`), pointers for blocks o
|
||||
// pointer for B operand
|
||||
pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
|
||||
|
||||
These pointers can then be updated in the inner loop as:
|
||||
And then updated in the inner loop as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -99,8 +99,8 @@ L2 Cache Optimizations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
As mentioned above, each program instance computes an :code:`[BLOCK_M, BLOCK_N]` block of :code:`C`.
|
||||
However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
|
||||
This means that a naive row-major ordering:
|
||||
It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||||
And unfortunately, a simple row-major ordering
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
@@ -110,7 +110,7 @@ This means that a naive row-major ordering:
|
||||
pid_m = pid / grid_n;
|
||||
pid_n = pid % grid_n;
|
||||
|
||||
is unlikely to result in optimal performance.
|
||||
is just not going to cut it.
|
||||
|
||||
One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before switching to the next column:
|
||||
@@ -134,7 +134,7 @@ Final Result
|
||||
-------------
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-189
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 115-190
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -148,23 +148,19 @@ Final Result
|
||||
# - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
|
||||
# - A autotuning *key* whose change in values will trigger evaluation of all the provided configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def sigmoid(x):
|
||||
ret_true = 1 / (1 + tl.exp(-x))
|
||||
ret_false = tl.exp(x) / (1 + tl.exp(x))
|
||||
return tl.where(x >= 0, ret_true, ret_false)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def swish(x):
|
||||
return x * sigmoid(x)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\
|
||||
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
|
||||
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -212,6 +208,10 @@ Final Result
|
||||
tl.store(C, acc, mask=mask)
|
||||
|
||||
|
||||
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
||||
@triton.jit
|
||||
def leaky_relu(x):
|
||||
return tl.where(x >= 0, x, 0.01*x)
|
||||
|
||||
|
||||
|
||||
@@ -219,17 +219,17 @@ Final Result
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 190-192
|
||||
|
||||
We can also create a convenience wrapper function that only takes two input tensors
|
||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 191-193
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 192-214
|
||||
We can now create a convenience wrapper function that only takes two input tensors
|
||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 193-214
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
@@ -241,12 +241,12 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
_matmul[grid](
|
||||
pgm = _matmul[grid](
|
||||
a, b, c, M, N, K, \
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
|
||||
ACTIVATION = activation
|
||||
)
|
||||
# return output
|
||||
# done; return the output tensor
|
||||
return c
|
||||
|
||||
|
||||
@@ -262,18 +262,18 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
Unit Test
|
||||
-----------
|
||||
|
||||
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
|
||||
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 219-229
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
#torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
c_0 = matmul(a, b, activation=swish)
|
||||
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
|
||||
c_0 = matmul(a, b, activation=None)
|
||||
c_1 = torch.matmul(a, b)
|
||||
print(c_0)
|
||||
print(c_1)
|
||||
print(triton.testing.allclose(c_0, c_1))
|
||||
@@ -288,32 +288,22 @@ We can test our custom matrix multiplication operation against a native torch im
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([[-4.5061e-05, 4.1656e+01, 1.7500e+01, ..., -2.7405e-02,
|
||||
-2.3251e-03, -0.0000e+00],
|
||||
[-1.0967e-04, -4.2915e-06, -0.0000e+00, ..., -1.4901e-06,
|
||||
-0.0000e+00, 1.4367e+01],
|
||||
[ 5.8156e+01, -0.0000e+00, -1.4603e-04, ..., 1.3930e+01,
|
||||
-2.1362e-01, 9.4062e+00],
|
||||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||||
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
|
||||
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
|
||||
...,
|
||||
[ 2.3703e+01, -9.2163e-02, -1.3471e-05, ..., -9.5215e-02,
|
||||
2.0047e+01, 1.4891e+01],
|
||||
[-1.9073e-06, 5.0664e+00, -0.0000e+00, ..., 2.0281e+01,
|
||||
-1.7583e-05, 3.8000e+01],
|
||||
[-1.7285e-05, 5.3945e+00, -1.3916e-01, ..., -2.0984e-01,
|
||||
5.3750e+00, -1.5993e-03]], device='cuda:0', dtype=torch.float16)
|
||||
tensor([[-4.4942e-05, 4.1656e+01, 1.7500e+01, ..., -2.7405e-02,
|
||||
-2.3232e-03, -0.0000e+00],
|
||||
[-1.1003e-04, -4.2915e-06, -0.0000e+00, ..., -1.4901e-06,
|
||||
-0.0000e+00, 1.4367e+01],
|
||||
[ 5.8156e+01, -0.0000e+00, -1.4639e-04, ..., 1.3930e+01,
|
||||
-2.1362e-01, 9.4062e+00],
|
||||
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||||
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
|
||||
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
|
||||
...,
|
||||
[ 2.3703e+01, -9.2163e-02, -1.3471e-05, ..., -9.5276e-02,
|
||||
2.0047e+01, 1.4891e+01],
|
||||
[-1.9073e-06, 5.0664e+00, -0.0000e+00, ..., 2.0281e+01,
|
||||
-1.7583e-05, 3.8000e+01],
|
||||
[-1.7345e-05, 5.3945e+00, -1.3916e-01, ..., -2.0984e-01,
|
||||
5.3750e+00, -1.6031e-03]], device='cuda:0', dtype=torch.float16)
|
||||
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||||
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor(True, device='cuda:0')
|
||||
|
||||
|
||||
@@ -326,9 +316,9 @@ Benchmark
|
||||
|
||||
Square Matrix Performance
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#
|
||||
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 236-262
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 236-268
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
@@ -337,23 +327,28 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
|
||||
x_vals=[128 * i for i in range(1, 33)], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
line_vals=['cublas', 'triton'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "Triton"], # label name for the lines
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], # possible values for `line_arg``
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], # label name for the lines
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], # line styles
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={}
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
silu = torch.nn.SiLU()
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
@@ -361,6 +356,7 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_03-matrix-multiplication_001.png
|
||||
:alt: 03 matrix multiplication
|
||||
:class: sphx-glr-single-img
|
||||
@@ -372,38 +368,40 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
M cuBLAS Triton
|
||||
0 512.0 20.164923 15.420235
|
||||
1 768.0 58.982401 40.215272
|
||||
2 1024.0 95.325090 72.315584
|
||||
3 1280.0 151.703703 117.028568
|
||||
4 1536.0 153.867127 150.593357
|
||||
5 1792.0 208.137481 190.498706
|
||||
6 2048.0 202.135135 151.146088
|
||||
7 2304.0 251.451276 178.267699
|
||||
8 2560.0 237.449270 218.453323
|
||||
9 2816.0 238.329010 200.987140
|
||||
10 3072.0 243.017615 223.806730
|
||||
11 3328.0 244.868356 210.500857
|
||||
12 3584.0 250.460703 232.941430
|
||||
13 3840.0 256.593972 225.697957
|
||||
14 4096.0 266.305018 247.634187
|
||||
15 4352.0 247.675667 237.797917
|
||||
16 4608.0 280.621108 260.713476
|
||||
17 4864.0 272.431168 252.534501
|
||||
18 5120.0 265.596772 245.223576
|
||||
19 5376.0 261.381955 244.335299
|
||||
20 5632.0 283.439220 260.383339
|
||||
21 5888.0 276.674704 254.103421
|
||||
22 6144.0 274.869441 252.078378
|
||||
23 6400.0 269.190319 249.027231
|
||||
24 6656.0 269.252160 249.104840
|
||||
25 6912.0 267.069377 247.115909
|
||||
26 7168.0 268.504352 246.006552
|
||||
27 7424.0 267.373291 246.355964
|
||||
28 7680.0 266.406511 245.760004
|
||||
29 7936.0 228.348876 248.331598
|
||||
30 8192.0 227.680622 247.977332
|
||||
matmul-performance:
|
||||
M cuBLAS cuBLAS (+ torch.nn.LeakyReLU) Triton Triton (+ LeakyReLU)
|
||||
0 128.0 0.455111 0.372364 0.512000 0.512000
|
||||
1 256.0 2.978909 2.340571 3.276800 2.978909
|
||||
2 384.0 7.372800 6.144000 8.507077 8.507077
|
||||
3 512.0 14.563555 11.915636 16.384000 16.384000
|
||||
4 640.0 22.260869 18.285714 23.272727 23.272727
|
||||
5 768.0 32.768000 26.810182 34.028308 34.028308
|
||||
6 896.0 39.025776 32.672744 39.025776 39.025776
|
||||
7 1024.0 49.932191 41.943041 52.428801 52.428801
|
||||
8 1152.0 44.566925 38.779015 46.656000 46.656000
|
||||
9 1280.0 51.200001 44.521738 56.109587 56.109587
|
||||
10 1408.0 64.138541 55.068446 65.684049 59.258433
|
||||
11 1536.0 79.526831 67.408458 75.296679 75.296679
|
||||
12 1664.0 63.372618 55.893862 61.636381 61.636381
|
||||
13 1792.0 72.983276 63.860363 68.953520 68.953520
|
||||
14 1920.0 66.782607 61.168141 68.776119 68.776119
|
||||
15 2048.0 73.262953 65.793006 75.234154 75.234154
|
||||
16 2176.0 82.473969 73.712993 79.540109 79.855747
|
||||
17 2304.0 68.251065 62.207998 73.051599 73.051599
|
||||
18 2432.0 71.305746 65.033481 80.963875 80.963875
|
||||
19 2560.0 77.649287 70.773218 76.560748 75.851852
|
||||
20 2688.0 82.463163 75.413632 82.106182 80.880718
|
||||
21 2816.0 82.602666 73.424595 78.442822 77.330158
|
||||
22 2944.0 82.784108 72.966370 80.122235 80.122235
|
||||
23 3072.0 79.638683 74.997490 79.082550 82.903517
|
||||
24 3200.0 84.099871 78.335374 89.385477 85.333333
|
||||
25 3328.0 83.226931 77.828428 81.346098 81.530349
|
||||
26 3456.0 79.351933 75.276907 82.858753 81.435930
|
||||
27 3584.0 87.466332 81.518940 95.858629 91.470385
|
||||
28 3712.0 84.230479 79.283603 81.682211 85.455380
|
||||
29 3840.0 84.421376 79.562590 87.355452 87.562949
|
||||
30 3968.0 93.006050 86.296981 84.038524 84.504108
|
||||
31 4096.0 93.662059 87.381330 83.729089 92.119235
|
||||
|
||||
|
||||
|
||||
@@ -411,7 +409,7 @@ We can now compare the performance of our kernel against CUTLASS. Here we focus
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 37.657 seconds)
|
||||
**Total running time of the script:** ( 2 minutes 12.630 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
||||
|
||||
@@ -12,7 +12,7 @@ Below is a gallery of tutorials for writing various basic operations with Triton
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model used by Triton - The triton.jit decorator, which constitutes the ...">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
@@ -33,7 +33,7 @@ Below is a gallery of tutorials for writing various basic operations with Triton
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - The reduction operators in Tr...">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - Reduction operators in Triton...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
Computation times
|
||||
=================
|
||||
**00:37.657** total execution time for **getting-started_tutorials** files:
|
||||
**03:31.819** total execution time for **getting-started_tutorials** files:
|
||||
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 00:37.657 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:12.630 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:00.000 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:08.184 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:00.000 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:11.005 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
||||
@@ -130,7 +130,7 @@ ul.search li a {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
ul.search li div.context {
|
||||
ul.search li p.context {
|
||||
color: #888;
|
||||
margin: 2px 0 0 30px;
|
||||
text-align: left;
|
||||
@@ -508,6 +508,63 @@ table.hlist td {
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
/* -- object description styles --------------------------------------------- */
|
||||
|
||||
.sig {
|
||||
font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace;
|
||||
}
|
||||
|
||||
.sig-name, code.descname {
|
||||
background-color: transparent;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.sig-name {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
code.descname {
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
.sig-prename, code.descclassname {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.optional {
|
||||
font-size: 1.3em;
|
||||
}
|
||||
|
||||
.sig-paren {
|
||||
font-size: larger;
|
||||
}
|
||||
|
||||
.sig-param.n {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* C++ specific styling */
|
||||
|
||||
.sig-inline.c-texpr,
|
||||
.sig-inline.cpp-texpr {
|
||||
font-family: unset;
|
||||
}
|
||||
|
||||
.sig.c .k, .sig.c .kt,
|
||||
.sig.cpp .k, .sig.cpp .kt {
|
||||
color: #0033B3;
|
||||
}
|
||||
|
||||
.sig.c .m,
|
||||
.sig.cpp .m {
|
||||
color: #1750EB;
|
||||
}
|
||||
|
||||
.sig.c .s, .sig.c .sc,
|
||||
.sig.cpp .s, .sig.cpp .sc {
|
||||
color: #067D17;
|
||||
}
|
||||
|
||||
|
||||
/* -- other body styles ----------------------------------------------------- */
|
||||
|
||||
@@ -634,14 +691,6 @@ dl.glossary dt {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
.optional {
|
||||
font-size: 1.3em;
|
||||
}
|
||||
|
||||
.sig-paren {
|
||||
font-size: larger;
|
||||
}
|
||||
|
||||
.versionmodified {
|
||||
font-style: italic;
|
||||
}
|
||||
@@ -770,8 +819,12 @@ div.code-block-caption code {
|
||||
|
||||
table.highlighttable td.linenos,
|
||||
span.linenos,
|
||||
div.doctest > div.highlight span.gp { /* gp: Generic.Prompt */
|
||||
div.highlight span.gp { /* gp: Generic.Prompt */
|
||||
user-select: none;
|
||||
-webkit-user-select: text; /* Safari fallback only */
|
||||
-webkit-user-select: none; /* Chrome/Safari */
|
||||
-moz-user-select: none; /* Firefox */
|
||||
-ms-user-select: none; /* IE10+ */
|
||||
}
|
||||
|
||||
div.code-block-caption span.caption-number {
|
||||
@@ -786,16 +839,6 @@ div.literal-block-wrapper {
|
||||
margin: 1em 0;
|
||||
}
|
||||
|
||||
code.descname {
|
||||
background-color: transparent;
|
||||
font-weight: bold;
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
code.descclassname {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
code.xref, a code {
|
||||
background-color: transparent;
|
||||
font-weight: bold;
|
||||
|
||||
@@ -301,12 +301,14 @@ var Documentation = {
|
||||
window.location.href = prevHref;
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case 39: // right
|
||||
var nextHref = $('link[rel="next"]').prop('href');
|
||||
if (nextHref) {
|
||||
window.location.href = nextHref;
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -145,7 +145,7 @@ div.sphx-glr-download a:hover {
|
||||
background-color: #d5d57e;
|
||||
}
|
||||
|
||||
.sphx-glr-example-title > :target::before {
|
||||
.sphx-glr-example-title:target::before {
|
||||
display: block;
|
||||
content: "";
|
||||
margin-top: -50px;
|
||||
|
||||
@@ -276,7 +276,7 @@ var Search = {
|
||||
setTimeout(function() {
|
||||
displayNextItem();
|
||||
}, 5);
|
||||
} else if (DOCUMENTATION_OPTIONS.HAS_SOURCE) {
|
||||
} else {
|
||||
$.ajax({url: requestUrl,
|
||||
dataType: "text",
|
||||
complete: function(jqxhr, textstatus) {
|
||||
@@ -289,12 +289,6 @@ var Search = {
|
||||
displayNextItem();
|
||||
}, 5);
|
||||
}});
|
||||
} else {
|
||||
// no source available, just display title
|
||||
Search.output.append(listItem);
|
||||
setTimeout(function() {
|
||||
displayNextItem();
|
||||
}, 5);
|
||||
}
|
||||
}
|
||||
// search finished, update title and status message
|
||||
@@ -509,7 +503,7 @@ var Search = {
|
||||
var excerpt = ((start > 0) ? '...' : '') +
|
||||
$.trim(text.substr(start, 240)) +
|
||||
((start + 240 - text.length) ? '...' : '');
|
||||
var rv = $('<div class="context"></div>').text(excerpt);
|
||||
var rv = $('<p class="context"></p>').text(excerpt);
|
||||
$.each(hlwords, function() {
|
||||
rv = rv.highlightText(this, 'highlighted');
|
||||
});
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
(function (global, factory) {
|
||||
typeof exports === 'object' && typeof module !== 'undefined' ? module.exports = factory() :
|
||||
typeof define === 'function' && define.amd ? define('underscore', factory) :
|
||||
(global = global || self, (function () {
|
||||
(global = typeof globalThis !== 'undefined' ? globalThis : global || self, (function () {
|
||||
var current = global._;
|
||||
var exports = global._ = factory();
|
||||
exports.noConflict = function () { global._ = current; return exports; };
|
||||
}()));
|
||||
}(this, (function () {
|
||||
// Underscore.js 1.12.0
|
||||
// Underscore.js 1.13.1
|
||||
// https://underscorejs.org
|
||||
// (c) 2009-2020 Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors
|
||||
// (c) 2009-2021 Jeremy Ashkenas, Julian Gonggrijp, and DocumentCloud and Investigative Reporters & Editors
|
||||
// Underscore may be freely distributed under the MIT license.
|
||||
|
||||
// Current version.
|
||||
var VERSION = '1.12.0';
|
||||
var VERSION = '1.13.1';
|
||||
|
||||
// Establish the root object, `window` (`self`) in the browser, `global`
|
||||
// on the server, or `this` in some virtual machines. We use `self`
|
||||
@@ -170,7 +170,7 @@
|
||||
var isArray = nativeIsArray || tagTester('Array');
|
||||
|
||||
// Internal function to check whether `key` is an own property name of `obj`.
|
||||
function has(obj, key) {
|
||||
function has$1(obj, key) {
|
||||
return obj != null && hasOwnProperty.call(obj, key);
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@
|
||||
(function() {
|
||||
if (!isArguments(arguments)) {
|
||||
isArguments = function(obj) {
|
||||
return has(obj, 'callee');
|
||||
return has$1(obj, 'callee');
|
||||
};
|
||||
}
|
||||
}());
|
||||
@@ -268,7 +268,7 @@
|
||||
|
||||
// Constructor is a special case.
|
||||
var prop = 'constructor';
|
||||
if (has(obj, prop) && !keys.contains(prop)) keys.push(prop);
|
||||
if (has$1(obj, prop) && !keys.contains(prop)) keys.push(prop);
|
||||
|
||||
while (nonEnumIdx--) {
|
||||
prop = nonEnumerableProps[nonEnumIdx];
|
||||
@@ -284,7 +284,7 @@
|
||||
if (!isObject(obj)) return [];
|
||||
if (nativeKeys) return nativeKeys(obj);
|
||||
var keys = [];
|
||||
for (var key in obj) if (has(obj, key)) keys.push(key);
|
||||
for (var key in obj) if (has$1(obj, key)) keys.push(key);
|
||||
// Ahem, IE < 9.
|
||||
if (hasEnumBug) collectNonEnumProps(obj, keys);
|
||||
return keys;
|
||||
@@ -318,24 +318,24 @@
|
||||
// If Underscore is called as a function, it returns a wrapped object that can
|
||||
// be used OO-style. This wrapper holds altered versions of all functions added
|
||||
// through `_.mixin`. Wrapped objects may be chained.
|
||||
function _(obj) {
|
||||
if (obj instanceof _) return obj;
|
||||
if (!(this instanceof _)) return new _(obj);
|
||||
function _$1(obj) {
|
||||
if (obj instanceof _$1) return obj;
|
||||
if (!(this instanceof _$1)) return new _$1(obj);
|
||||
this._wrapped = obj;
|
||||
}
|
||||
|
||||
_.VERSION = VERSION;
|
||||
_$1.VERSION = VERSION;
|
||||
|
||||
// Extracts the result from a wrapped and chained object.
|
||||
_.prototype.value = function() {
|
||||
_$1.prototype.value = function() {
|
||||
return this._wrapped;
|
||||
};
|
||||
|
||||
// Provide unwrapping proxies for some methods used in engine operations
|
||||
// such as arithmetic and JSON stringification.
|
||||
_.prototype.valueOf = _.prototype.toJSON = _.prototype.value;
|
||||
_$1.prototype.valueOf = _$1.prototype.toJSON = _$1.prototype.value;
|
||||
|
||||
_.prototype.toString = function() {
|
||||
_$1.prototype.toString = function() {
|
||||
return String(this._wrapped);
|
||||
};
|
||||
|
||||
@@ -370,8 +370,8 @@
|
||||
// Internal recursive comparison function for `_.isEqual`.
|
||||
function deepEq(a, b, aStack, bStack) {
|
||||
// Unwrap any wrapped objects.
|
||||
if (a instanceof _) a = a._wrapped;
|
||||
if (b instanceof _) b = b._wrapped;
|
||||
if (a instanceof _$1) a = a._wrapped;
|
||||
if (b instanceof _$1) b = b._wrapped;
|
||||
// Compare `[[Class]]` names.
|
||||
var className = toString.call(a);
|
||||
if (className !== toString.call(b)) return false;
|
||||
@@ -463,7 +463,7 @@
|
||||
while (length--) {
|
||||
// Deep compare each member
|
||||
key = _keys[length];
|
||||
if (!(has(b, key) && eq(a[key], b[key], aStack, bStack))) return false;
|
||||
if (!(has$1(b, key) && eq(a[key], b[key], aStack, bStack))) return false;
|
||||
}
|
||||
}
|
||||
// Remove the first object from the stack of traversed objects.
|
||||
@@ -642,15 +642,15 @@
|
||||
|
||||
// Normalize a (deep) property `path` to array.
|
||||
// Like `_.iteratee`, this function can be customized.
|
||||
function toPath(path) {
|
||||
function toPath$1(path) {
|
||||
return isArray(path) ? path : [path];
|
||||
}
|
||||
_.toPath = toPath;
|
||||
_$1.toPath = toPath$1;
|
||||
|
||||
// Internal wrapper for `_.toPath` to enable minification.
|
||||
// Similar to `cb` for `_.iteratee`.
|
||||
function toPath$1(path) {
|
||||
return _.toPath(path);
|
||||
function toPath(path) {
|
||||
return _$1.toPath(path);
|
||||
}
|
||||
|
||||
// Internal function to obtain a nested property in `obj` along `path`.
|
||||
@@ -668,19 +668,19 @@
|
||||
// `undefined`, return `defaultValue` instead.
|
||||
// The `path` is normalized through `_.toPath`.
|
||||
function get(object, path, defaultValue) {
|
||||
var value = deepGet(object, toPath$1(path));
|
||||
var value = deepGet(object, toPath(path));
|
||||
return isUndefined(value) ? defaultValue : value;
|
||||
}
|
||||
|
||||
// Shortcut function for checking if an object has a given property directly on
|
||||
// itself (in other words, not on a prototype). Unlike the internal `has`
|
||||
// function, this public version can also traverse nested properties.
|
||||
function has$1(obj, path) {
|
||||
path = toPath$1(path);
|
||||
function has(obj, path) {
|
||||
path = toPath(path);
|
||||
var length = path.length;
|
||||
for (var i = 0; i < length; i++) {
|
||||
var key = path[i];
|
||||
if (!has(obj, key)) return false;
|
||||
if (!has$1(obj, key)) return false;
|
||||
obj = obj[key];
|
||||
}
|
||||
return !!length;
|
||||
@@ -703,7 +703,7 @@
|
||||
// Creates a function that, when passed an object, will traverse that object’s
|
||||
// properties down the given `path`, specified as an array of keys or indices.
|
||||
function property(path) {
|
||||
path = toPath$1(path);
|
||||
path = toPath(path);
|
||||
return function(obj) {
|
||||
return deepGet(obj, path);
|
||||
};
|
||||
@@ -747,12 +747,12 @@
|
||||
function iteratee(value, context) {
|
||||
return baseIteratee(value, context, Infinity);
|
||||
}
|
||||
_.iteratee = iteratee;
|
||||
_$1.iteratee = iteratee;
|
||||
|
||||
// The function we call internally to generate a callback. It invokes
|
||||
// `_.iteratee` if overridden, otherwise `baseIteratee`.
|
||||
function cb(value, context, argCount) {
|
||||
if (_.iteratee !== iteratee) return _.iteratee(value, context);
|
||||
if (_$1.iteratee !== iteratee) return _$1.iteratee(value, context);
|
||||
return baseIteratee(value, context, argCount);
|
||||
}
|
||||
|
||||
@@ -840,7 +840,7 @@
|
||||
|
||||
// By default, Underscore uses ERB-style template delimiters. Change the
|
||||
// following template settings to use alternative delimiters.
|
||||
var templateSettings = _.templateSettings = {
|
||||
var templateSettings = _$1.templateSettings = {
|
||||
evaluate: /<%([\s\S]+?)%>/g,
|
||||
interpolate: /<%=([\s\S]+?)%>/g,
|
||||
escape: /<%-([\s\S]+?)%>/g
|
||||
@@ -868,13 +868,20 @@
|
||||
return '\\' + escapes[match];
|
||||
}
|
||||
|
||||
// In order to prevent third-party code injection through
|
||||
// `_.templateSettings.variable`, we test it against the following regular
|
||||
// expression. It is intentionally a bit more liberal than just matching valid
|
||||
// identifiers, but still prevents possible loopholes through defaults or
|
||||
// destructuring assignment.
|
||||
var bareIdentifier = /^\s*(\w|\$)+\s*$/;
|
||||
|
||||
// JavaScript micro-templating, similar to John Resig's implementation.
|
||||
// Underscore templating handles arbitrary delimiters, preserves whitespace,
|
||||
// and correctly escapes quotes within interpolated code.
|
||||
// NB: `oldSettings` only exists for backwards compatibility.
|
||||
function template(text, settings, oldSettings) {
|
||||
if (!settings && oldSettings) settings = oldSettings;
|
||||
settings = defaults({}, settings, _.templateSettings);
|
||||
settings = defaults({}, settings, _$1.templateSettings);
|
||||
|
||||
// Combine delimiters into one regular expression via alternation.
|
||||
var matcher = RegExp([
|
||||
@@ -903,8 +910,17 @@
|
||||
});
|
||||
source += "';\n";
|
||||
|
||||
var argument = settings.variable;
|
||||
if (argument) {
|
||||
// Insure against third-party code injection. (CVE-2021-23358)
|
||||
if (!bareIdentifier.test(argument)) throw new Error(
|
||||
'variable is not a bare identifier: ' + argument
|
||||
);
|
||||
} else {
|
||||
// If a variable is not specified, place data values in local scope.
|
||||
if (!settings.variable) source = 'with(obj||{}){\n' + source + '}\n';
|
||||
source = 'with(obj||{}){\n' + source + '}\n';
|
||||
argument = 'obj';
|
||||
}
|
||||
|
||||
source = "var __t,__p='',__j=Array.prototype.join," +
|
||||
"print=function(){__p+=__j.call(arguments,'');};\n" +
|
||||
@@ -912,18 +928,17 @@
|
||||
|
||||
var render;
|
||||
try {
|
||||
render = new Function(settings.variable || 'obj', '_', source);
|
||||
render = new Function(argument, '_', source);
|
||||
} catch (e) {
|
||||
e.source = source;
|
||||
throw e;
|
||||
}
|
||||
|
||||
var template = function(data) {
|
||||
return render.call(this, data, _);
|
||||
return render.call(this, data, _$1);
|
||||
};
|
||||
|
||||
// Provide the compiled source as a convenience for precompilation.
|
||||
var argument = settings.variable || 'obj';
|
||||
template.source = 'function(' + argument + '){\n' + source + '}';
|
||||
|
||||
return template;
|
||||
@@ -933,7 +948,7 @@
|
||||
// is invoked with its parent as context. Returns the value of the final
|
||||
// child, or `fallback` if any child is undefined.
|
||||
function result(obj, path, fallback) {
|
||||
path = toPath$1(path);
|
||||
path = toPath(path);
|
||||
var length = path.length;
|
||||
if (!length) {
|
||||
return isFunction$1(fallback) ? fallback.call(obj) : fallback;
|
||||
@@ -959,7 +974,7 @@
|
||||
|
||||
// Start chaining a wrapped Underscore object.
|
||||
function chain(obj) {
|
||||
var instance = _(obj);
|
||||
var instance = _$1(obj);
|
||||
instance._chain = true;
|
||||
return instance;
|
||||
}
|
||||
@@ -993,7 +1008,7 @@
|
||||
return bound;
|
||||
});
|
||||
|
||||
partial.placeholder = _;
|
||||
partial.placeholder = _$1;
|
||||
|
||||
// Create a function bound to a given object (assigning `this`, and arguments,
|
||||
// optionally).
|
||||
@@ -1012,7 +1027,7 @@
|
||||
var isArrayLike = createSizePropertyCheck(getLength);
|
||||
|
||||
// Internal implementation of a recursive `flatten` function.
|
||||
function flatten(input, depth, strict, output) {
|
||||
function flatten$1(input, depth, strict, output) {
|
||||
output = output || [];
|
||||
if (!depth && depth !== 0) {
|
||||
depth = Infinity;
|
||||
@@ -1025,7 +1040,7 @@
|
||||
if (isArrayLike(value) && (isArray(value) || isArguments$1(value))) {
|
||||
// Flatten current level of array or arguments object.
|
||||
if (depth > 1) {
|
||||
flatten(value, depth - 1, strict, output);
|
||||
flatten$1(value, depth - 1, strict, output);
|
||||
idx = output.length;
|
||||
} else {
|
||||
var j = 0, len = value.length;
|
||||
@@ -1042,7 +1057,7 @@
|
||||
// are the method names to be bound. Useful for ensuring that all callbacks
|
||||
// defined on an object belong to it.
|
||||
var bindAll = restArguments(function(obj, keys) {
|
||||
keys = flatten(keys, false, false);
|
||||
keys = flatten$1(keys, false, false);
|
||||
var index = keys.length;
|
||||
if (index < 1) throw new Error('bindAll must be passed function names');
|
||||
while (index--) {
|
||||
@@ -1057,7 +1072,7 @@
|
||||
var memoize = function(key) {
|
||||
var cache = memoize.cache;
|
||||
var address = '' + (hasher ? hasher.apply(this, arguments) : key);
|
||||
if (!has(cache, address)) cache[address] = func.apply(this, arguments);
|
||||
if (!has$1(cache, address)) cache[address] = func.apply(this, arguments);
|
||||
return cache[address];
|
||||
};
|
||||
memoize.cache = {};
|
||||
@@ -1074,7 +1089,7 @@
|
||||
|
||||
// Defers a function, scheduling it to run after the current call stack has
|
||||
// cleared.
|
||||
var defer = partial(delay, _, 1);
|
||||
var defer = partial(delay, _$1, 1);
|
||||
|
||||
// Returns a function, that, when invoked, will only be triggered at most once
|
||||
// during a given window of time. Normally, the throttled function will run
|
||||
@@ -1420,7 +1435,7 @@
|
||||
if (isFunction$1(path)) {
|
||||
func = path;
|
||||
} else {
|
||||
path = toPath$1(path);
|
||||
path = toPath(path);
|
||||
contextPath = path.slice(0, -1);
|
||||
path = path[path.length - 1];
|
||||
}
|
||||
@@ -1562,7 +1577,7 @@
|
||||
// Groups the object's values by a criterion. Pass either a string attribute
|
||||
// to group by, or a function that returns the criterion.
|
||||
var groupBy = group(function(result, value, key) {
|
||||
if (has(result, key)) result[key].push(value); else result[key] = [value];
|
||||
if (has$1(result, key)) result[key].push(value); else result[key] = [value];
|
||||
});
|
||||
|
||||
// Indexes the object's values by a criterion, similar to `_.groupBy`, but for
|
||||
@@ -1575,7 +1590,7 @@
|
||||
// either a string attribute to count by, or a function that returns the
|
||||
// criterion.
|
||||
var countBy = group(function(result, value, key) {
|
||||
if (has(result, key)) result[key]++; else result[key] = 1;
|
||||
if (has$1(result, key)) result[key]++; else result[key] = 1;
|
||||
});
|
||||
|
||||
// Split a collection into two arrays: one whose elements all pass the given
|
||||
@@ -1618,7 +1633,7 @@
|
||||
keys = allKeys(obj);
|
||||
} else {
|
||||
iteratee = keyInObj;
|
||||
keys = flatten(keys, false, false);
|
||||
keys = flatten$1(keys, false, false);
|
||||
obj = Object(obj);
|
||||
}
|
||||
for (var i = 0, length = keys.length; i < length; i++) {
|
||||
@@ -1636,7 +1651,7 @@
|
||||
iteratee = negate(iteratee);
|
||||
if (keys.length > 1) context = keys[1];
|
||||
} else {
|
||||
keys = map(flatten(keys, false, false), String);
|
||||
keys = map(flatten$1(keys, false, false), String);
|
||||
iteratee = function(value, key) {
|
||||
return !contains(keys, key);
|
||||
};
|
||||
@@ -1681,14 +1696,14 @@
|
||||
|
||||
// Flatten out an array, either recursively (by default), or up to `depth`.
|
||||
// Passing `true` or `false` as `depth` means `1` or `Infinity`, respectively.
|
||||
function flatten$1(array, depth) {
|
||||
return flatten(array, depth, false);
|
||||
function flatten(array, depth) {
|
||||
return flatten$1(array, depth, false);
|
||||
}
|
||||
|
||||
// Take the difference between one array and a number of other arrays.
|
||||
// Only the elements present in just the first array will remain.
|
||||
var difference = restArguments(function(array, rest) {
|
||||
rest = flatten(rest, true, true);
|
||||
rest = flatten$1(rest, true, true);
|
||||
return filter(array, function(value){
|
||||
return !contains(rest, value);
|
||||
});
|
||||
@@ -1734,7 +1749,7 @@
|
||||
// Produce an array that contains the union: each distinct element from all of
|
||||
// the passed-in arrays.
|
||||
var union = restArguments(function(arrays) {
|
||||
return uniq(flatten(arrays, true, true));
|
||||
return uniq(flatten$1(arrays, true, true));
|
||||
});
|
||||
|
||||
// Produce an array that contains every item shared between all the
|
||||
@@ -1821,26 +1836,26 @@
|
||||
|
||||
// Helper function to continue chaining intermediate results.
|
||||
function chainResult(instance, obj) {
|
||||
return instance._chain ? _(obj).chain() : obj;
|
||||
return instance._chain ? _$1(obj).chain() : obj;
|
||||
}
|
||||
|
||||
// Add your own custom functions to the Underscore object.
|
||||
function mixin(obj) {
|
||||
each(functions(obj), function(name) {
|
||||
var func = _[name] = obj[name];
|
||||
_.prototype[name] = function() {
|
||||
var func = _$1[name] = obj[name];
|
||||
_$1.prototype[name] = function() {
|
||||
var args = [this._wrapped];
|
||||
push.apply(args, arguments);
|
||||
return chainResult(this, func.apply(_, args));
|
||||
return chainResult(this, func.apply(_$1, args));
|
||||
};
|
||||
});
|
||||
return _;
|
||||
return _$1;
|
||||
}
|
||||
|
||||
// Add all mutator `Array` functions to the wrapper.
|
||||
each(['pop', 'push', 'reverse', 'shift', 'sort', 'splice', 'unshift'], function(name) {
|
||||
var method = ArrayProto[name];
|
||||
_.prototype[name] = function() {
|
||||
_$1.prototype[name] = function() {
|
||||
var obj = this._wrapped;
|
||||
if (obj != null) {
|
||||
method.apply(obj, arguments);
|
||||
@@ -1855,7 +1870,7 @@
|
||||
// Add all accessor `Array` functions to the wrapper.
|
||||
each(['concat', 'join', 'slice'], function(name) {
|
||||
var method = ArrayProto[name];
|
||||
_.prototype[name] = function() {
|
||||
_$1.prototype[name] = function() {
|
||||
var obj = this._wrapped;
|
||||
if (obj != null) obj = method.apply(obj, arguments);
|
||||
return chainResult(this, obj);
|
||||
@@ -1909,12 +1924,12 @@
|
||||
clone: clone,
|
||||
tap: tap,
|
||||
get: get,
|
||||
has: has$1,
|
||||
has: has,
|
||||
mapObject: mapObject,
|
||||
identity: identity,
|
||||
constant: constant,
|
||||
noop: noop,
|
||||
toPath: toPath,
|
||||
toPath: toPath$1,
|
||||
property: property,
|
||||
propertyOf: propertyOf,
|
||||
matcher: matcher,
|
||||
@@ -1997,7 +2012,7 @@
|
||||
tail: rest,
|
||||
drop: rest,
|
||||
compact: compact,
|
||||
flatten: flatten$1,
|
||||
flatten: flatten,
|
||||
without: without,
|
||||
uniq: uniq,
|
||||
unique: uniq,
|
||||
@@ -2011,17 +2026,17 @@
|
||||
range: range,
|
||||
chunk: chunk,
|
||||
mixin: mixin,
|
||||
'default': _
|
||||
'default': _$1
|
||||
};
|
||||
|
||||
// Default Export
|
||||
|
||||
// Add all of the Underscore functions to the wrapper object.
|
||||
var _$1 = mixin(allExports);
|
||||
var _ = mixin(allExports);
|
||||
// Legacy Node.js API.
|
||||
_$1._ = _$1;
|
||||
_._ = _;
|
||||
|
||||
return _$1;
|
||||
return _;
|
||||
|
||||
})));
|
||||
//# sourceMappingURL=underscore.js.map
|
||||
//# sourceMappingURL=underscore-umd.js.map
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
|
||||
<script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
|
||||
<script src="_static/jquery.js"></script>
|
||||
<script src="_static/underscore.js"></script>
|
||||
<script src="_static/doctools.js"></script>
|
||||
@@ -87,18 +90,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
@@ -89,7 +92,7 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Installation</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#binary-distributions">Binary Distributions</a></li>
|
||||
@@ -102,13 +105,13 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,7 +92,7 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
@@ -103,13 +106,13 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -191,9 +194,9 @@ to download the full example code</p>
|
||||
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The basic programming model used by Triton</p></li>
|
||||
<li><p>The <cite>triton.jit</cite> decorator, which constitutes the main entry point for writing Triton kernels.</p></li>
|
||||
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations</p></li>
|
||||
<li><p>The basic programming model of Triton</p></li>
|
||||
<li><p>The <cite>triton.jit</cite> decorator, which is used to define Triton kernels.</p></li>
|
||||
<li><p>The best practices for validating and benchmarking your custom ops against native reference implementations</p></li>
|
||||
</ul>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
@@ -225,25 +228,25 @@ to download the full example code</p>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can also declara a helper function that handles allocating the output vector
|
||||
and enqueueing the kernel.</p>
|
||||
<p>Let’s also declare a helper function that to (1) allocate the output vector
|
||||
and (2) enqueueing the above kernel.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="c1"># The SPMD launch grid denotes the number of kernel instances that should execute in parallel.</span>
|
||||
<span class="c1"># The SPMD launch grid denotes the number of kernel instances that run in parallel.</span>
|
||||
<span class="c1"># It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]),</span> <span class="p">)</span>
|
||||
<span class="c1"># NOTE:</span>
|
||||
<span class="c1"># - torch.tensor objects are implicitly converted to pointers to their first element.</span>
|
||||
<span class="c1"># - `triton.jit`'ed functions can be subscripted with a launch grid to obtain a callable GPU kernel</span>
|
||||
<span class="c1"># - each torch.tensor object is implicitly converted into a pointer to its first element.</span>
|
||||
<span class="c1"># - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel</span>
|
||||
<span class="c1"># - don't forget to pass meta-parameters as keywords arguments</span>
|
||||
<span class="n">_add</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
|
||||
<span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still</span>
|
||||
<span class="c1"># running asynchronously.</span>
|
||||
<span class="c1"># running asynchronously at this point.</span>
|
||||
<span class="k">return</span> <span class="n">z</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can now use the above function to compute the sum of two <cite>torch.tensor</cite> objects and test our results:</p>
|
||||
<p>We can now use the above function to compute the element-wise sum of two <cite>torch.tensor</cite> objects and test its correctness:</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
@@ -266,7 +269,7 @@ The maximum difference between torch and triton is 0.0
|
||||
<div class="section" id="benchmark">
|
||||
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline">¶</a></h2>
|
||||
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
|
||||
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
|
||||
for different problem sizes.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||||
@@ -294,11 +297,32 @@ for different problem sizes.</p>
|
||||
</div>
|
||||
<p>We can now run the decorated function above. Pass <cite>show_plots=True</cite> to see the plots and/or
|
||||
<a href="#id1"><span class="problematic" id="id2">`</span></a>save_path=’/path/to/results/’ to save them to disk along with raw CSV data</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 7.682 seconds)</p>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>vector-add-performance:
|
||||
size Torch Triton
|
||||
0 4096.0 9.600000 9.600000
|
||||
1 8192.0 19.200000 19.200000
|
||||
2 16384.0 38.400001 38.400001
|
||||
3 32768.0 76.800002 63.999998
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 722.823517
|
||||
10 4194304.0 780.190482 780.190482
|
||||
11 8388608.0 812.429770 812.429770
|
||||
12 16777216.0 833.084721 833.084721
|
||||
13 33554432.0 843.811163 843.811163
|
||||
14 67108864.0 848.362445 849.278610
|
||||
15 134217728.0 850.656574 851.577704
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 11.005 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,10 +36,11 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
@@ -90,7 +93,7 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
@@ -106,13 +109,13 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,10 +195,11 @@ to download the full example code</p>
|
||||
</div>
|
||||
<div class="sphx-glr-example-title section" id="fused-softmax">
|
||||
<span id="sphx-glr-getting-started-tutorials-02-fused-softmax-py"></span><h1>Fused Softmax<a class="headerlink" href="#fused-softmax" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:</p>
|
||||
<p>In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
|
||||
You will learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
|
||||
<li><p>The reduction operators in Triton.</p></li>
|
||||
<li><p>Reduction operators in Triton.</p></li>
|
||||
</ul>
|
||||
<div class="section" id="motivations">
|
||||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||||
@@ -205,15 +209,16 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
|
||||
|
||||
<span class="c1"># Compute the row-wise softmax of x</span>
|
||||
<span class="nd">@torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">script</span>
|
||||
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">x_max</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># read MN elements ; write MN elements</span>
|
||||
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">denominator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="n">denominator</span> <span class="o">=</span> <span class="n">numerator</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span>
|
||||
@@ -222,14 +227,14 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
</div>
|
||||
<p>When implemented naively in pytorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
|
||||
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip.
|
||||
This solution would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
|
||||
In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.</p>
|
||||
Doing so would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
|
||||
The <cite>torch.jit.script</cite> flags aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.</p>
|
||||
</div>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally “pad” tiles and guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
|
||||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
|
||||
@@ -239,6 +244,7 @@ so we need to internally “pad” tiles and guard the memory operations properl
|
||||
<span class="c1"># row index</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># col indices</span>
|
||||
<span class="c1"># here BLOCK is the smallest power of two greater than `N`</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="c1"># the memory address of all the elements</span>
|
||||
<span class="c1"># that we want to load can be computed as follows</span>
|
||||
@@ -272,11 +278,10 @@ so we need to internally “pad” tiles and guard the memory operations properl
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="c1"># The block size is the smallest power of two greater than the number of columns in `x`</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
|
||||
<span class="c1"># Another trick we can use is to ask the compiler to parallelize each</span>
|
||||
<span class="c1"># row-normalization more aggressively -- i.e., with more warps -- vectors</span>
|
||||
<span class="c1"># that are longer</span>
|
||||
<span class="c1"># Another trick we can use is to ask the compiler to use more threads per row by</span>
|
||||
<span class="c1"># increasing the number of warps (`num_warps`) over which each row is distributed.</span>
|
||||
<span class="c1"># You will see in the next tutorial how to auto-tune this value in a more natural</span>
|
||||
<span class="c1"># way so you don't have to come up with manual heuristics yourself</span>
|
||||
<span class="c1"># way so you don't have to come up with manual heuristics yourself.</span>
|
||||
<span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
|
||||
<span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">2048</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
|
||||
<span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
|
||||
@@ -312,10 +317,11 @@ We will then compare its performance against (1) <code class="code docutils lite
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'N'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">50</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">100</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'naive'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s1">'Naive'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'torch-native'</span><span class="p">,</span> <span class="s1">'torch-jit'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"Triton"</span><span class="p">,</span> <span class="s2">"Torch (native)"</span><span class="p">,</span> <span class="s2">"Torch (jit)"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">)],</span> <span class="c1"># line styles</span>
|
||||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"GB/s"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"softmax-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'M'</span><span class="p">:</span> <span class="mi">4096</span><span class="p">}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
|
||||
@@ -323,30 +329,48 @@ We will then compare its performance against (1) <code class="code docutils lite
|
||||
<span class="p">)</span>
|
||||
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch'</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch-native'</span><span class="p">:</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'naive'</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch-jit'</span><span class="p">:</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">nelement</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="mf">1e-9</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<img alt="02 fused softmax" class="sphx-glr-single-img" src="../../_images/sphx_glr_02-fused-softmax_001.png" />
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>softmax-performance:
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 512.000001 546.133347 273.066674
|
||||
1 384.0 585.142862 585.142862 267.130429
|
||||
2 512.0 630.153853 606.814814 264.258068
|
||||
3 640.0 682.666684 640.000002 269.473696
|
||||
4 768.0 702.171410 664.216187 273.066663
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 405.755985 329.483481
|
||||
94 12288.0 812.429770 415.661740 329.602681
|
||||
95 12416.0 810.840807 412.149375 329.173158
|
||||
96 12544.0 810.925276 412.971190 329.292871
|
||||
97 12672.0 811.007961 412.097543 329.142870
|
||||
|
||||
[98 rows x 4 columns]
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>In the above plot, we can see that:</p>
|
||||
<blockquote>
|
||||
<div><ul class="simple">
|
||||
<li><p>Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.</p></li>
|
||||
<li><p>Triton is significantly faster than <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> for very large input matrices. My guess from looking at the source-code of the <a class="reference external" href="https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240">PyTorch kernel</a> is that PyTorch only partially fuses the computation of the softmax.
|
||||
This means that – when temporary data is too large to fit entirely in the GPU’s cache – it transfers almost twice the amount of data necessary.
|
||||
<li><p>Triton is 2-3x faster than the Torch JIT.</p></li>
|
||||
<li><p>Triton is even faster than <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code>. My guess from looking at the source-code of the <a class="reference external" href="https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240">PyTorch kernel</a> is that PyTorch only partially fuses the computation of the softmax.
|
||||
This means that – when temporary data is too large to fit entirely in the GPU’s cache – it transfers almost twice the amount of memory necessary.
|
||||
Note that our Triton kernel is not only faster than PyTorch’s CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
|
||||
</ul>
|
||||
</div></blockquote>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 20.250 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 8.184 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,7 +92,7 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
@@ -113,13 +116,13 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -199,7 +202,7 @@ to download the full example code</p>
|
||||
</div>
|
||||
<div class="sphx-glr-example-title section" id="matrix-multiplication">
|
||||
<span id="sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"></span><h1>Matrix Multiplication<a class="headerlink" href="#matrix-multiplication" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, you will write a 25-lines high-performance matrix multiplication kernel that achieves close to peak performance on modern GPUs.
|
||||
<p>In this tutorial, you will write a 25-lines high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS.
|
||||
You will specifically learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>Block-level matrix multiplications</p></li>
|
||||
@@ -210,9 +213,9 @@ You will specifically learn about:</p>
|
||||
<div class="section" id="motivations">
|
||||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Matrix multiplications are a key building block of most modern high-performance computing systems.
|
||||
They are notoriously hard to optimize, hence their implementation is typically done by hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., mixture of experts, fused activation functions, etc.).
|
||||
For this reason, this tutorial will show you how to implement efficient matrix multiplications yourself with Triton, in a way that is easy to customize and extend.</p>
|
||||
They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called “kernel libraries” (e.g., cuBLAS).
|
||||
Unfortunately, these libraries are often proprietary and cannot be easily customized to accomodate the needs of modern deep learning workloads (e.g., fused activation functions).
|
||||
In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.</p>
|
||||
<p>Roughly speaking, the kernel that we will write will implement the following blocked algorithm:</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
|
||||
@@ -232,19 +235,19 @@ For this reason, this tutorial will show you how to implement efficient matrix m
|
||||
</div>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The above algorithm is actually fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the 2D pointer arithmetic that must be done to specify the memory locations for the blocks of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> that we need to read in the inner loop.</p>
|
||||
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
|
||||
The main difficulty comes from the computation of the memory locations at which blocks of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need multi-dimensional pointer arithmetics.</p>
|
||||
<div class="section" id="pointer-arithmetics">
|
||||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by <code class="code docutils literal notranslate"><span class="pre">&X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_x_0</span> <span class="pre">+</span> <span class="pre">j*stride_x_1</span></code>.
|
||||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_M,</span> <span class="pre">k:k+BLOCK_K]</span></code> and <code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_N]</span></code> can be defined in pseudo-code as:</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:];</span>
|
||||
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:];</span>
|
||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||||
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
<p>Which means that, at initialization (i.e., <code class="code docutils literal notranslate"><span class="pre">k</span> <span class="pre">=</span> <span class="pre">0</span></code>), pointers for blocks of A and B can be initialized in Triton as:</p>
|
||||
<p>Which means that pointers for blocks of A and B can be initialized (i.e., <code class="code docutils literal notranslate"><span class="pre">k=0</span></code>) in Triton as:</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pid_m</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">pid_n</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||||
@@ -258,7 +261,7 @@ Therefore, blocks of pointers for <code class="code docutils literal notranslate
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
<p>These pointers can then be updated in the inner loop as:</p>
|
||||
<p>And then updated in the inner loop as follows:</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_a_1</span><span class="p">;</span>
|
||||
<span class="n">pb</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_b_0</span><span class="p">;</span>
|
||||
@@ -269,8 +272,8 @@ Therefore, blocks of pointers for <code class="code docutils literal notranslate
|
||||
<div class="section" id="l2-cache-optimizations">
|
||||
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
|
||||
<p>As mentioned above, each program instance computes an <code class="code docutils literal notranslate"><span class="pre">[BLOCK_M,</span> <span class="pre">BLOCK_N]</span></code> block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
|
||||
However, the order in which these blocks are computer matters, since it affects the L2 cache hit rate of our program.
|
||||
This means that a naive row-major ordering:</p>
|
||||
It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program.
|
||||
And unfortunately, a simple row-major ordering</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-Python notranslate"><div class="highlight"><pre><span></span><span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_M</span><span class="p">;</span>
|
||||
@@ -280,7 +283,7 @@ This means that a naive row-major ordering:</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
<p>is unlikely to result in optimal performance.</p>
|
||||
<p>is just not going to cut it.</p>
|
||||
<p>One possible solution is to launch blocks in an order that promotes data reuse.
|
||||
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_M</span></code> rows before switching to the next column:</p>
|
||||
<blockquote>
|
||||
@@ -308,23 +311,19 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
<span class="c1"># - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try</span>
|
||||
<span class="c1"># - A autotuning *key* whose change in values will trigger evaluation of all the provided configs</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">ret_true</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="n">ret_false</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">ret_true</span><span class="p">,</span> <span class="n">ret_false</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">swish</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">autotune</span><span class="p">(</span>
|
||||
<span class="n">configs</span><span class="o">=</span><span class="p">[</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">256</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>\
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">Config</span><span class="p">({</span><span class="s1">'BLOCK_M'</span><span class="p">:</span> <span class="mi">32</span> <span class="p">,</span> <span class="s1">'BLOCK_N'</span><span class="p">:</span> <span class="mi">64</span> <span class="p">,</span> <span class="s1">'BLOCK_K'</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="s1">'GROUP_M'</span><span class="p">:</span> <span class="mi">8</span><span class="p">},</span> <span class="n">num_stages</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
||||
<span class="c1">#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),</span>
|
||||
<span class="p">],</span>
|
||||
<span class="n">key</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span>
|
||||
<span class="p">)</span>
|
||||
@@ -370,10 +369,16 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
<span class="n">C</span> <span class="o">=</span> <span class="n">C</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_cm</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_cn</span><span class="p">)</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">acc</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`</span>
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="mf">0.01</span><span class="o">*</span><span class="n">x</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can also create a convenience wrapper function that only takes two input tensors
|
||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the kernel</p>
|
||||
<p>We can now create a convenience wrapper function that only takes two input tensors
|
||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||||
<span class="c1"># checks constraints</span>
|
||||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">"incompatible dimensions"</span>
|
||||
@@ -385,56 +390,46 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
<span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||||
<span class="c1"># launch kernel</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">META</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_N'</span><span class="p">]),</span> <span class="p">)</span>
|
||||
<span class="n">_matmul</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||
<span class="n">pgm</span> <span class="o">=</span> <span class="n">_matmul</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> \
|
||||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>\
|
||||
<span class="n">ACTIVATION</span> <span class="o">=</span> <span class="n">activation</span>
|
||||
<span class="p">)</span>
|
||||
<span class="c1"># return output</span>
|
||||
<span class="c1"># done; return the output tensor</span>
|
||||
<span class="k">return</span> <span class="n">c</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="unit-test">
|
||||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||||
<p>We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1">#torch.manual_seed(0)</span>
|
||||
<p>We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||
<span class="n">c_0</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">swish</span><span class="p">)</span>
|
||||
<span class="n">c_1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()(</span><span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||||
<span class="n">c_0</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||||
<span class="n">c_1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">c_0</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">c_1</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">c_0</span><span class="p">,</span> <span class="n">c_1</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[-4.5061e-05, 4.1656e+01, 1.7500e+01, ..., -2.7405e-02,
|
||||
-2.3251e-03, -0.0000e+00],
|
||||
[-1.0967e-04, -4.2915e-06, -0.0000e+00, ..., -1.4901e-06,
|
||||
-0.0000e+00, 1.4367e+01],
|
||||
[ 5.8156e+01, -0.0000e+00, -1.4603e-04, ..., 1.3930e+01,
|
||||
-2.1362e-01, 9.4062e+00],
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3984, 24.4531, -32.3438],
|
||||
[ 6.3555, -19.6094, 34.0938, ..., -5.8945, 5.2891, 6.8867],
|
||||
[-32.0625, 5.9492, 15.3984, ..., -21.3906, -23.9844, -10.1328],
|
||||
...,
|
||||
[ 2.3703e+01, -9.2163e-02, -1.3471e-05, ..., -9.5215e-02,
|
||||
2.0047e+01, 1.4891e+01],
|
||||
[-1.9073e-06, 5.0664e+00, -0.0000e+00, ..., 2.0281e+01,
|
||||
-1.7583e-05, 3.8000e+01],
|
||||
[-1.7285e-05, 5.3945e+00, -1.3916e-01, ..., -2.0984e-01,
|
||||
5.3750e+00, -1.5993e-03]], device='cuda:0', dtype=torch.float16)
|
||||
tensor([[-4.4942e-05, 4.1656e+01, 1.7500e+01, ..., -2.7405e-02,
|
||||
-2.3232e-03, -0.0000e+00],
|
||||
[-1.1003e-04, -4.2915e-06, -0.0000e+00, ..., -1.4901e-06,
|
||||
-0.0000e+00, 1.4367e+01],
|
||||
[ 5.8156e+01, -0.0000e+00, -1.4639e-04, ..., 1.3930e+01,
|
||||
-2.1362e-01, 9.4062e+00],
|
||||
[ -5.7031, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||||
[ 25.5000, 24.3281, -8.4688, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9844, 11.8906, ..., 5.5898, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -32.3438],
|
||||
[ 6.3516, -19.6094, 34.0938, ..., -5.8906, 5.2812, 6.8828],
|
||||
[-32.0625, 5.9531, 15.3984, ..., -21.4062, -23.9844, -10.1328],
|
||||
...,
|
||||
[ 2.3703e+01, -9.2163e-02, -1.3471e-05, ..., -9.5276e-02,
|
||||
2.0047e+01, 1.4891e+01],
|
||||
[-1.9073e-06, 5.0664e+00, -0.0000e+00, ..., 2.0281e+01,
|
||||
-1.7583e-05, 3.8000e+01],
|
||||
[-1.7345e-05, 5.3945e+00, -1.3916e-01, ..., -2.0984e-01,
|
||||
5.3750e+00, -1.6031e-03]], device='cuda:0', dtype=torch.float16)
|
||||
[ -5.7070, 7.4492, 8.2656, ..., -10.6953, -40.0000, 17.7500],
|
||||
[ 25.5000, 24.3438, -8.4609, ..., -18.9375, 32.5312, -29.9219],
|
||||
[ -5.3477, 4.9805, 11.8828, ..., 5.5859, 6.4023, -17.3125]],
|
||||
device='cuda:0', dtype=torch.float16)
|
||||
tensor(True, device='cuda:0')
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -443,27 +438,32 @@ tensor(True, device='cuda:0')
|
||||
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="section" id="square-matrix-performance">
|
||||
<h3>Square Matrix Performance<a class="headerlink" href="#square-matrix-performance" title="Permalink to this headline">¶</a></h3>
|
||||
<p>We can now compare the performance of our kernel against CUTLASS. Here we focus on square matrices, but feel free to arrange the script as you wish to compare any other matrix shape.#</p>
|
||||
<p>We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'cublas + relu'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'triton + relu'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"cuBLAS (+ torch.nn.LeakyReLU)"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s2">"Triton (+ LeakyReLU)"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">styles</span><span class="o">=</span><span class="p">[(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'green'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'-'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'blue'</span><span class="p">,</span> <span class="s1">'--'</span><span class="p">)],</span> <span class="c1"># line styles</span>
|
||||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"TFLOPS"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"matmul-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="p">{}</span>
|
||||
<span class="p">)</span>
|
||||
<span class="p">)</span>
|
||||
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
|
||||
<span class="n">silu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">SiLU</span><span class="p">()</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas'</span><span class="p">:</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas + relu'</span><span class="p">:</span>
|
||||
<span class="n">torch_relu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch_relu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)))</span>
|
||||
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton + relu'</span><span class="p">:</span>
|
||||
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">leaky_relu</span><span class="p">))</span>
|
||||
<span class="n">perf</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span> <span class="o">*</span> <span class="n">K</span> <span class="o">*</span> <span class="mf">1e-12</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
|
||||
|
||||
@@ -473,41 +473,43 @@ tensor(True, device='cuda:0')
|
||||
</div>
|
||||
<img alt="03 matrix multiplication" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span> M cuBLAS Triton
|
||||
0 512.0 20.164923 15.420235
|
||||
1 768.0 58.982401 40.215272
|
||||
2 1024.0 95.325090 72.315584
|
||||
3 1280.0 151.703703 117.028568
|
||||
4 1536.0 153.867127 150.593357
|
||||
5 1792.0 208.137481 190.498706
|
||||
6 2048.0 202.135135 151.146088
|
||||
7 2304.0 251.451276 178.267699
|
||||
8 2560.0 237.449270 218.453323
|
||||
9 2816.0 238.329010 200.987140
|
||||
10 3072.0 243.017615 223.806730
|
||||
11 3328.0 244.868356 210.500857
|
||||
12 3584.0 250.460703 232.941430
|
||||
13 3840.0 256.593972 225.697957
|
||||
14 4096.0 266.305018 247.634187
|
||||
15 4352.0 247.675667 237.797917
|
||||
16 4608.0 280.621108 260.713476
|
||||
17 4864.0 272.431168 252.534501
|
||||
18 5120.0 265.596772 245.223576
|
||||
19 5376.0 261.381955 244.335299
|
||||
20 5632.0 283.439220 260.383339
|
||||
21 5888.0 276.674704 254.103421
|
||||
22 6144.0 274.869441 252.078378
|
||||
23 6400.0 269.190319 249.027231
|
||||
24 6656.0 269.252160 249.104840
|
||||
25 6912.0 267.069377 247.115909
|
||||
26 7168.0 268.504352 246.006552
|
||||
27 7424.0 267.373291 246.355964
|
||||
28 7680.0 266.406511 245.760004
|
||||
29 7936.0 228.348876 248.331598
|
||||
30 8192.0 227.680622 247.977332
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
|
||||
M cuBLAS cuBLAS (+ torch.nn.LeakyReLU) Triton Triton (+ LeakyReLU)
|
||||
0 128.0 0.455111 0.372364 0.512000 0.512000
|
||||
1 256.0 2.978909 2.340571 3.276800 2.978909
|
||||
2 384.0 7.372800 6.144000 8.507077 8.507077
|
||||
3 512.0 14.563555 11.915636 16.384000 16.384000
|
||||
4 640.0 22.260869 18.285714 23.272727 23.272727
|
||||
5 768.0 32.768000 26.810182 34.028308 34.028308
|
||||
6 896.0 39.025776 32.672744 39.025776 39.025776
|
||||
7 1024.0 49.932191 41.943041 52.428801 52.428801
|
||||
8 1152.0 44.566925 38.779015 46.656000 46.656000
|
||||
9 1280.0 51.200001 44.521738 56.109587 56.109587
|
||||
10 1408.0 64.138541 55.068446 65.684049 59.258433
|
||||
11 1536.0 79.526831 67.408458 75.296679 75.296679
|
||||
12 1664.0 63.372618 55.893862 61.636381 61.636381
|
||||
13 1792.0 72.983276 63.860363 68.953520 68.953520
|
||||
14 1920.0 66.782607 61.168141 68.776119 68.776119
|
||||
15 2048.0 73.262953 65.793006 75.234154 75.234154
|
||||
16 2176.0 82.473969 73.712993 79.540109 79.855747
|
||||
17 2304.0 68.251065 62.207998 73.051599 73.051599
|
||||
18 2432.0 71.305746 65.033481 80.963875 80.963875
|
||||
19 2560.0 77.649287 70.773218 76.560748 75.851852
|
||||
20 2688.0 82.463163 75.413632 82.106182 80.880718
|
||||
21 2816.0 82.602666 73.424595 78.442822 77.330158
|
||||
22 2944.0 82.784108 72.966370 80.122235 80.122235
|
||||
23 3072.0 79.638683 74.997490 79.082550 82.903517
|
||||
24 3200.0 84.099871 78.335374 89.385477 85.333333
|
||||
25 3328.0 83.226931 77.828428 81.346098 81.530349
|
||||
26 3456.0 79.351933 75.276907 82.858753 81.435930
|
||||
27 3584.0 87.466332 81.518940 95.858629 91.470385
|
||||
28 3712.0 84.230479 79.283603 81.682211 85.455380
|
||||
29 3840.0 84.421376 79.562590 87.355452 87.562949
|
||||
30 3968.0 93.006050 86.296981 84.038524 84.504108
|
||||
31 4096.0 93.662059 87.381330 83.729089 92.119235
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 37.657 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 12.630 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,7 +92,7 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorials</a><ul>
|
||||
@@ -99,13 +102,13 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -179,13 +182,13 @@
|
||||
<div class="section" id="tutorials">
|
||||
<span id="sphx-glr-getting-started-tutorials"></span><h1>Tutorials<a class="headerlink" href="#tutorials" title="Permalink to this headline">¶</a></h1>
|
||||
<p>Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.</p>
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model used by Triton - The triton.jit decorator, which constitutes the ..."><div class="figure align-default" id="id1">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic programming model of Triton - The triton.jit decorator, which is used to define Tri..."><div class="figure align-default" id="id1">
|
||||
<img alt="Vector Addition" src="../../_images/sphx_glr_01-vector-add_thumb.png" />
|
||||
<p class="caption"><span class="caption-text"><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a></span><a class="headerlink" href="#id1" title="Permalink to this image">¶</a></p>
|
||||
</div>
|
||||
</div><div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - The reduction operators in Tr..."><div class="figure align-default" id="id2">
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - Reduction operators in Triton..."><div class="figure align-default" id="id2">
|
||||
<img alt="Fused Softmax" src="../../_images/sphx_glr_02-fused-softmax_thumb.png" />
|
||||
<p class="caption"><span class="caption-text"><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a></span><a class="headerlink" href="#id2" title="Permalink to this image">¶</a></p>
|
||||
</div>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -87,18 +90,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -171,7 +174,7 @@
|
||||
|
||||
<div class="section" id="computation-times">
|
||||
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
||||
<p><strong>00:37.657</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<p><strong>03:31.819</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<table class="docutils align-default">
|
||||
<colgroup>
|
||||
<col style="width: 85%" />
|
||||
@@ -180,15 +183,15 @@
|
||||
</colgroup>
|
||||
<tbody>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
|
||||
<td><p>00:37.657</p></td>
|
||||
<td><p>02:12.630</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>00:00.000</p></td>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>01:08.184</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>00:00.000</p></td>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>00:11.005</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
|
||||
<script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
|
||||
<script src="_static/jquery.js"></script>
|
||||
<script src="_static/underscore.js"></script>
|
||||
<script src="_static/doctools.js"></script>
|
||||
@@ -88,18 +91,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,18 +92,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Introduction</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#motivations">Motivations</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,10 +36,11 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
@@ -89,18 +92,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Related Work</a><ul>
|
||||
@@ -287,39 +290,24 @@ i & j
|
||||
<div class="section" id="scheduling-languages">
|
||||
<h2>Scheduling Languages<a class="headerlink" href="#scheduling-languages" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Separation of concerns cite{dijkstra82} is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a <strong>scheduling language</strong>. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.</p>
|
||||
<div class="highlight-python notranslate"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
|
||||
<span class="normal"> 2</span>
|
||||
<span class="normal"> 3</span>
|
||||
<span class="normal"> 4</span>
|
||||
<span class="normal"> 5</span>
|
||||
<span class="normal"> 6</span>
|
||||
<span class="normal"> 7</span>
|
||||
<span class="normal"> 8</span>
|
||||
<span class="normal"> 9</span>
|
||||
<span class="normal">10</span>
|
||||
<span class="normal">11</span>
|
||||
<span class="normal">12</span>
|
||||
<span class="normal">13</span>
|
||||
<span class="normal">14</span>
|
||||
<span class="normal">15</span>
|
||||
<span class="normal">16</span></pre></div></td><td class="code"><div class="highlight"><pre><span></span><span class="o">//</span> <span class="n">algorithm</span>
|
||||
<span class="n">Var</span> <span class="n">x</span><span class="p">(</span><span class="s2">"x"</span><span class="p">),</span> <span class="n">y</span><span class="p">(</span><span class="s2">"y"</span><span class="p">);</span>
|
||||
<span class="n">Func</span> <span class="n">matmul</span><span class="p">(</span><span class="s2">"matmul"</span><span class="p">);</span>
|
||||
<span class="n">RDom</span> <span class="n">k</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">matrix_size</span><span class="p">);</span>
|
||||
<span class="n">RVar</span> <span class="n">ki</span><span class="p">;</span>
|
||||
<span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">=</span> <span class="mf">0.0</span><span class="n">f</span><span class="p">;</span>
|
||||
<span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">+=</span> <span class="n">A</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">B</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">k</span><span class="p">);</span>
|
||||
<span class="o">//</span> <span class="n">schedule</span>
|
||||
<span class="n">Var</span> <span class="n">xi</span><span class="p">(</span><span class="s2">"xi"</span><span class="p">),</span> <span class="n">xo</span><span class="p">(</span><span class="s2">"xo"</span><span class="p">),</span> <span class="n">yo</span><span class="p">(</span><span class="s2">"yo"</span><span class="p">),</span> <span class="n">yi</span><span class="p">(</span><span class="s2">"yo"</span><span class="p">),</span> <span class="n">yii</span><span class="p">(</span><span class="s2">"yii"</span><span class="p">),</span> <span class="n">xii</span><span class="p">(</span><span class="s2">"xii"</span><span class="p">);</span>
|
||||
<span class="n">matmul</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">8</span><span class="p">);</span>
|
||||
<span class="n">matmul</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">xi</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">xi</span><span class="p">,</span> <span class="n">xi</span><span class="p">,</span> <span class="n">xii</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
|
||||
<span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">yi</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">yii</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||||
<span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)</span>
|
||||
<span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">xii</span><span class="p">,</span> <span class="n">yii</span><span class="p">,</span> <span class="n">xi</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="o">.</span><span class="n">parallel</span><span class="p">(</span><span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">xii</span><span class="p">)</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">xi</span><span class="p">)</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">yii</span><span class="p">);</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="linenos"> 1</span><span class="o">//</span> <span class="n">algorithm</span>
|
||||
<span class="linenos"> 2</span><span class="n">Var</span> <span class="n">x</span><span class="p">(</span><span class="s2">"x"</span><span class="p">),</span> <span class="n">y</span><span class="p">(</span><span class="s2">"y"</span><span class="p">);</span>
|
||||
<span class="linenos"> 3</span><span class="n">Func</span> <span class="n">matmul</span><span class="p">(</span><span class="s2">"matmul"</span><span class="p">);</span>
|
||||
<span class="linenos"> 4</span><span class="n">RDom</span> <span class="n">k</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">matrix_size</span><span class="p">);</span>
|
||||
<span class="linenos"> 5</span><span class="n">RVar</span> <span class="n">ki</span><span class="p">;</span>
|
||||
<span class="linenos"> 6</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">=</span> <span class="mf">0.0</span><span class="n">f</span><span class="p">;</span>
|
||||
<span class="linenos"> 7</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">+=</span> <span class="n">A</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">B</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">k</span><span class="p">);</span>
|
||||
<span class="linenos"> 8</span><span class="o">//</span> <span class="n">schedule</span>
|
||||
<span class="linenos"> 9</span><span class="n">Var</span> <span class="n">xi</span><span class="p">(</span><span class="s2">"xi"</span><span class="p">),</span> <span class="n">xo</span><span class="p">(</span><span class="s2">"xo"</span><span class="p">),</span> <span class="n">yo</span><span class="p">(</span><span class="s2">"yo"</span><span class="p">),</span> <span class="n">yi</span><span class="p">(</span><span class="s2">"yo"</span><span class="p">),</span> <span class="n">yii</span><span class="p">(</span><span class="s2">"yii"</span><span class="p">),</span> <span class="n">xii</span><span class="p">(</span><span class="s2">"xii"</span><span class="p">);</span>
|
||||
<span class="linenos">10</span><span class="n">matmul</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">8</span><span class="p">);</span>
|
||||
<span class="linenos">11</span><span class="n">matmul</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="linenos">12</span> <span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">xi</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">xi</span><span class="p">,</span> <span class="n">xi</span><span class="p">,</span> <span class="n">xii</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
|
||||
<span class="linenos">13</span> <span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">yi</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">yii</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||||
<span class="linenos">14</span> <span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)</span>
|
||||
<span class="linenos">15</span> <span class="o">.</span><span class="n">reorder</span><span class="p">(</span><span class="n">xii</span><span class="p">,</span> <span class="n">yii</span><span class="p">,</span> <span class="n">xi</span><span class="p">,</span> <span class="n">ki</span><span class="p">,</span> <span class="n">yi</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="linenos">16</span> <span class="o">.</span><span class="n">parallel</span><span class="p">(</span><span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="n">xii</span><span class="p">)</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">xi</span><span class="p">)</span><span class="o">.</span><span class="n">unroll</span><span class="p">(</span><span class="n">yii</span><span class="p">);</span>
|
||||
</pre></div>
|
||||
</td></tr></table></div>
|
||||
</div>
|
||||
<p>The resulting code may however not be completely portable, as schedules can sometimes rely on execution models (e.g., SPMD) or hardware intrinsics (e.g., matrix-multiply-accumulate) that are not widely available. This issue can be mitigated by auto-scheduling mechanisms <a class="reference internal" href="#mullapudi2016" id="id14"><span>[MULLAPUDI2016]</span></a>.</p>
|
||||
<div class="section" id="id15">
|
||||
<h3>Advantages<a class="headerlink" href="#id15" title="Permalink to this headline">¶</a></h3>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.html">triton</a><ul class="current">
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">triton.jit</a></li>
|
||||
@@ -103,7 +106,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -179,8 +182,8 @@
|
||||
<div class="section" id="triton-jit">
|
||||
<h1>triton.jit<a class="headerlink" href="#triton-jit" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.jit">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.</span></code><code class="sig-name descname"><span class="pre">jit</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.jit" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.jit">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.</span></span><span class="sig-name descname"><span class="pre">jit</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.jit" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Decorator for JIT-compiling a function using the Triton compiler.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Note</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -116,7 +119,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,8 +195,8 @@
|
||||
<div class="section" id="triton-language-arange">
|
||||
<h1>triton.language.arange<a class="headerlink" href="#triton-language-arange" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.arange">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">arange</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">start</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">end</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.arange" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.arange">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">arange</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">start</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">end</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.arange" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns contiguous values within the open interval [<code class="code docutils literal notranslate"><span class="pre">start</span></code>, <code class="code docutils literal notranslate"><span class="pre">end</span></code>).</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-atomic-cas">
|
||||
<h1>triton.language.atomic_cas<a class="headerlink" href="#triton-language-atomic-cas" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.atomic_cas">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">atomic_cas</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">cmp</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.atomic_cas" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.atomic_cas">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">atomic_cas</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">cmp</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.atomic_cas" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Performs an atomic “compare-and-swap” and the memory locations specified by <code class="code docutils literal notranslate"><span class="pre">pointer</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-atomic-xchg">
|
||||
<h1>triton.language.atomic_xchg<a class="headerlink" href="#triton-language-atomic-xchg" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.atomic_xchg">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">atomic_xchg</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.atomic_xchg" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.atomic_xchg">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">atomic_xchg</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.atomic_xchg" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Swaps the <em>old</em> values stored at location <code class="code docutils literal notranslate"><span class="pre">pointer</span></code> with the new values given by <code class="code docutils literal notranslate"><span class="pre">val</span></code>. Returns the old values.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -117,7 +120,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -193,8 +196,8 @@
|
||||
<div class="section" id="triton-language-broadcast-to">
|
||||
<h1>triton.language.broadcast_to<a class="headerlink" href="#triton-language-broadcast-to" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.broadcast_to">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">broadcast_to</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.broadcast_to" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.broadcast_to">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">broadcast_to</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.broadcast_to" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Tries to broadcast the given block to a new <code class="code docutils literal notranslate"><span class="pre">shape</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -115,7 +118,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -191,8 +194,8 @@
|
||||
<div class="section" id="triton-language-dot">
|
||||
<h1>triton.language.dot<a class="headerlink" href="#triton-language-dot" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.dot">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">dot</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.dot">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">dot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.dot" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the matrix product of two blocks.</p>
|
||||
<p>The two blocks must be two dimensionals and have compatible inner dimensions.</p>
|
||||
<dl class="field-list simple">
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-exp">
|
||||
<h1>triton.language.exp<a class="headerlink" href="#triton-language-exp" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.exp">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">exp</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.exp" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.exp">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">exp</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.exp" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Computes the element-wise exponential of <code class="code docutils literal notranslate"><span class="pre">x</span></code></p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-load">
|
||||
<h1>triton.language.load<a class="headerlink" href="#triton-language-load" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.load">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">load</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.load" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.load">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">load</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.load" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Return a block of data whose values are, elementwise, loaded from memory at location defined by <code class="code docutils literal notranslate"><span class="pre">pointer</span></code>.</p>
|
||||
<p><code class="code docutils literal notranslate"><span class="pre">mask</span></code> and <code class="code docutils literal notranslate"><span class="pre">other</span></code> are implicitly broadcast to <code class="code docutils literal notranslate"><span class="pre">pointer.shape</span></code>.</p>
|
||||
<p><code class="code docutils literal notranslate"><span class="pre">other</span></code> is implicitly typecast to <code class="code docutils literal notranslate"><span class="pre">pointer.dtype.element_ty</span></code>.</p>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-log">
|
||||
<h1>triton.language.log<a class="headerlink" href="#triton-language-log" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.log">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">log</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.log" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.log">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">log</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.log" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Computes the element-wise natural logarithm of <code class="code docutils literal notranslate"><span class="pre">x</span></code></p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -117,7 +120,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -193,8 +196,8 @@
|
||||
<div class="section" id="triton-language-max">
|
||||
<h1>triton.language.max<a class="headerlink" href="#triton-language-max" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.max">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">max</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.max" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.max">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">max</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.max" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the maximum value of all elements in the <code class="code docutils literal notranslate"><span class="pre">input</span></code> block along the provided <code class="code docutils literal notranslate"><span class="pre">axis</span></code></p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -116,7 +119,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,8 +195,8 @@
|
||||
<div class="section" id="triton-language-maximum">
|
||||
<h1>triton.language.maximum<a class="headerlink" href="#triton-language-maximum" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.maximum">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">maximum</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.maximum" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.maximum">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">maximum</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.maximum" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Computes the element-wise maximum of <code class="code docutils literal notranslate"><span class="pre">x</span></code> and <code class="code docutils literal notranslate"><span class="pre">y</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -117,7 +120,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -193,8 +196,8 @@
|
||||
<div class="section" id="triton-language-min">
|
||||
<h1>triton.language.min<a class="headerlink" href="#triton-language-min" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.min">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">min</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.min" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.min">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">min</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.min" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the minimum value of all elements in the <code class="code docutils literal notranslate"><span class="pre">input</span></code> block along the provided <code class="code docutils literal notranslate"><span class="pre">axis</span></code></p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -116,7 +119,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,8 +195,8 @@
|
||||
<div class="section" id="triton-language-minimum">
|
||||
<h1>triton.language.minimum<a class="headerlink" href="#triton-language-minimum" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.minimum">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">minimum</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.minimum" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.minimum">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">minimum</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.minimum" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Computes the element-wise minimum of <code class="code docutils literal notranslate"><span class="pre">x</span></code> and <code class="code docutils literal notranslate"><span class="pre">y</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -115,7 +118,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -191,8 +194,8 @@
|
||||
<div class="section" id="triton-language-multiple-of">
|
||||
<h1>triton.language.multiple_of<a class="headerlink" href="#triton-language-multiple-of" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.multiple_of">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">multiple_of</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.multiple_of" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.multiple_of">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">multiple_of</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.multiple_of" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Let the compiler knows that the values in <code class="code docutils literal notranslate"><span class="pre">input</span></code> are all multiples of <code class="code docutils literal notranslate"><span class="pre">value</span></code>.
|
||||
:param builder: IR builder to generate code into
|
||||
:type builder: triton.ir.builder, optional from within JIT’ed functions</p>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -116,7 +119,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,8 +195,8 @@
|
||||
<div class="section" id="triton-language-num-programs">
|
||||
<h1>triton.language.num_programs<a class="headerlink" href="#triton-language-num-programs" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.num_programs">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">num_programs</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.num_programs" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.num_programs">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">num_programs</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.num_programs" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the number of program instances launched along the given <code class="code docutils literal notranslate"><span class="pre">axis</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -116,7 +119,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,8 +195,8 @@
|
||||
<div class="section" id="triton-language-program-id">
|
||||
<h1>triton.language.program_id<a class="headerlink" href="#triton-language-program-id" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.program_id">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">program_id</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.program_id" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.program_id">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">program_id</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.program_id" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the id of the current program instance along the given <code class="code docutils literal notranslate"><span class="pre">axis</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -117,7 +120,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -193,8 +196,8 @@
|
||||
<div class="section" id="triton-language-ravel">
|
||||
<h1>triton.language.ravel<a class="headerlink" href="#triton-language-ravel" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.ravel">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">ravel</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.ravel" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.ravel">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">ravel</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.ravel" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns a contiguous flattened view of <code class="code docutils literal notranslate"><span class="pre">x</span></code></p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -117,7 +120,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -193,8 +196,8 @@
|
||||
<div class="section" id="triton-language-reshape">
|
||||
<h1>triton.language.reshape<a class="headerlink" href="#triton-language-reshape" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.reshape">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">reshape</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.reshape" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.reshape">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">reshape</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.reshape" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Tries to reshape the given block to a new shape.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-sigmoid">
|
||||
<h1>triton.language.sigmoid<a class="headerlink" href="#triton-language-sigmoid" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.sigmoid">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">sigmoid</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.sigmoid" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.sigmoid">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">sigmoid</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.sigmoid" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Computes the element-wise sigmoid of <code class="code docutils literal notranslate"><span class="pre">x</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-softmax">
|
||||
<h1>triton.language.softmax<a class="headerlink" href="#triton-language-softmax" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.softmax">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">softmax</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.softmax" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.softmax">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">softmax</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.softmax" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Computes the element-wise softmax of <code class="code docutils literal notranslate"><span class="pre">x</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -118,7 +121,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -194,8 +197,8 @@
|
||||
<div class="section" id="triton-language-store">
|
||||
<h1>triton.language.store<a class="headerlink" href="#triton-language-store" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.store">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">store</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.store" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.store">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">store</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pointer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.store" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Stores <code class="code docutils literal notranslate"><span class="pre">value</span></code> block of elements in memory, element-wise, at the memory locations specified by <code class="code docutils literal notranslate"><span class="pre">pointer</span></code>.</p>
|
||||
<p><code class="code docutils literal notranslate"><span class="pre">value</span></code> is implicitly broadcast to <code class="code docutils literal notranslate"><span class="pre">pointer.shape</span></code> and typecast to <code class="code docutils literal notranslate"><span class="pre">pointer.dtype.element_ty</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -117,7 +120,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -193,8 +196,8 @@
|
||||
<div class="section" id="triton-language-sum">
|
||||
<h1>triton.language.sum<a class="headerlink" href="#triton-language-sum" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.sum">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">sum</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.sum" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.sum">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">sum</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.sum" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns the sum of all elements in the <code class="code docutils literal notranslate"><span class="pre">input</span></code> block along the provided <code class="code docutils literal notranslate"><span class="pre">axis</span></code></p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -115,7 +118,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -191,8 +194,8 @@
|
||||
<div class="section" id="triton-language-where">
|
||||
<h1>triton.language.where<a class="headerlink" href="#triton-language-where" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.where">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">where</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">condition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.where" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.where">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">where</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">condition</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.where" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns a block of elements from either <code class="code docutils literal notranslate"><span class="pre">x</span></code> or <code class="code docutils literal notranslate"><span class="pre">y</span></code>, depending on <code class="code docutils literal notranslate"><span class="pre">condition</span></code>.</p>
|
||||
<p>Note that <code class="code docutils literal notranslate"><span class="pre">x</span></code> and <code class="code docutils literal notranslate"><span class="pre">y</span></code> are always evaluated regardless of the value of <code class="code docutils literal notranslate"><span class="pre">condition</span></code>.</p>
|
||||
<p>If you want to avoid unintented memory operations, use the <code class="code docutils literal notranslate"><span class="pre">mask</span></code> arguments in <cite>triton.load</cite> and <cite>triton.store</cite> instead.</p>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="../triton.language.html">triton.language</a><ul class="current">
|
||||
@@ -116,7 +119,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -192,8 +195,8 @@
|
||||
<div class="section" id="triton-language-zeros">
|
||||
<h1>triton.language.zeros<a class="headerlink" href="#triton-language-zeros" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.language.zeros">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.language.</span></code><code class="sig-name descname"><span class="pre">zeros</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.zeros" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.language.zeros">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.language.</span></span><span class="sig-name descname"><span class="pre">zeros</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">builder</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.language.zeros" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Returns a block filled with the scalar value 0 for the given <code class="code docutils literal notranslate"><span class="pre">shape</span></code> and <code class="code docutils literal notranslate"><span class="pre">dtype</span></code>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.language.html">triton.language</a></li>
|
||||
@@ -105,7 +108,7 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -181,12 +184,12 @@
|
||||
<div class="section" id="triton-testing-benchmark">
|
||||
<h1>triton.testing.Benchmark<a class="headerlink" href="#triton-testing-benchmark" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py class">
|
||||
<dt id="triton.testing.Benchmark">
|
||||
<em class="property"><span class="pre">class</span> </em><code class="sig-prename descclassname"><span class="pre">triton.testing.</span></code><code class="sig-name descname"><span class="pre">Benchmark</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_arg</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">plot_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">xlabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ylabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.Benchmark" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.testing.Benchmark">
|
||||
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">triton.testing.</span></span><span class="sig-name descname"><span class="pre">Benchmark</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_arg</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">plot_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">xlabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ylabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">color</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">styles</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.Benchmark" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>This class is used by the <code class="code docutils literal notranslate"><span class="pre">perf_report</span></code> function to generate line plots with a concise API.</p>
|
||||
<dl class="py method">
|
||||
<dt id="triton.testing.Benchmark.__init__">
|
||||
<code class="sig-name descname"><span class="pre">__init__</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_arg</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">plot_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">xlabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ylabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.Benchmark.__init__" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.testing.Benchmark.__init__">
|
||||
<span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_arg</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_vals</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">line_names</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">plot_name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">xlabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ylabel</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">x_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">y_log</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">color</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">styles</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.Benchmark.__init__" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Constructor</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.language.html">triton.language</a></li>
|
||||
@@ -105,7 +108,7 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -181,8 +184,8 @@
|
||||
<div class="section" id="triton-testing-do-bench">
|
||||
<h1>triton.testing.do_bench<a class="headerlink" href="#triton-testing-do-bench" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.testing.do_bench">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.testing.</span></code><code class="sig-name descname"><span class="pre">do_bench</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">warmup</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">25</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">rep</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_to_none</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">percentiles</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0.2,</span> <span class="pre">0.8]</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.do_bench" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.testing.do_bench">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.testing.</span></span><span class="sig-name descname"><span class="pre">do_bench</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">warmup</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">25</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">rep</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_to_none</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">percentiles</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[0.2,</span> <span class="pre">0.8]</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.do_bench" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Benchmark the runtime of the provided function. By default, return the median runtime of <code class="code docutils literal notranslate"><span class="pre">fn</span></code> along with
|
||||
the 20-th and 80-th performance percentile.</p>
|
||||
<dl class="field-list simple">
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../triton.language.html">triton.language</a></li>
|
||||
@@ -105,7 +108,7 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
@@ -181,8 +184,8 @@
|
||||
<div class="section" id="triton-testing-perf-report">
|
||||
<h1>triton.testing.perf_report<a class="headerlink" href="#triton-testing-perf-report" title="Permalink to this headline">¶</a></h1>
|
||||
<dl class="py function">
|
||||
<dt id="triton.testing.perf_report">
|
||||
<code class="sig-prename descclassname"><span class="pre">triton.testing.</span></code><code class="sig-name descname"><span class="pre">perf_report</span></code><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">benchmarks</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.perf_report" title="Permalink to this definition">¶</a></dt>
|
||||
<dt class="sig sig-object py" id="triton.testing.perf_report">
|
||||
<span class="sig-prename descclassname"><span class="pre">triton.testing.</span></span><span class="sig-name descname"><span class="pre">perf_report</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">benchmarks</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#triton.testing.perf_report" title="Permalink to this definition">¶</a></dt>
|
||||
<dd><p>Mark a function for benchmarking. The benchmark can then be executed by using the <code class="code docutils literal notranslate"><span class="pre">.run</span></code> method on the return value.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters</dt>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">triton</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="generated/triton.jit.html">triton.jit</a></li>
|
||||
@@ -103,7 +106,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="triton.html">triton</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">triton.language</a><ul>
|
||||
@@ -155,7 +158,7 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -34,6 +36,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
@@ -89,12 +92,12 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="triton.language.html">triton.language</a></li>
|
||||
@@ -105,7 +108,7 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
|
||||
@@ -35,6 +37,7 @@
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
|
||||
<script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
|
||||
<script src="_static/jquery.js"></script>
|
||||
<script src="_static/underscore.js"></script>
|
||||
<script src="_static/doctools.js"></script>
|
||||
@@ -90,18 +93,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<p class="caption" role="heading"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
||||
|
||||