[GH-PAGES] Updated website
@@ -0,0 +1,136 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax layer that outperform's PyTorch implementation and learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- The syntax and usage of reduction operators in Triton.\n- The automatic vectorization capabilities of the Triton compiler.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Motivations\nCustom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.\nLet us consider instead the case of a simple (numerically stabilized) softmax operation:\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n\n\n# Compute the row-wise softmax of x\ndef naive_softmax(x):\n # read MN elements ; write M elements\n x_max = torch.max(x, axis=1)[0]\n # read 2MN elements ; write MN elements\n z = x - x_max[:, None]\n # read MN elements ; write MN elements\n numerator = torch.exp(x)\n # read MN elements ; write M elements\n denominator = torch.sum(numerator, axis=1)\n # read 2MN elements ; write MN elements\n ret = numerator / denominator[:, None]\n # in total: read 7MN elements ; wrote 3MN + 2M elements\n return ret"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for $x \\in R^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nInstead, we want to write a custom \"fused\" pytorch operators that only reads X once and does all the necessary computations on-chip.\nThis would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of 5x.\nIn practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\nOur softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:\n\n .. code-block:: C\n\n __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){\n // row index\n int m = get_program_id(0);\n // column indices\n int n [BLOCK] = 0 ... BLOCK;\n // the memory address of all the elements\n // that we want to load can be computed as follows\n float* px [BLOCK] = X + m*stride_xm + n;\n // because BLOCK has to be a power of two\n // (per Triton-C specs), it is important\n // to guard each memory operation with predicates\n // or we will read out of bounds\n bool check[BLOCK] = n < N;\n float x [BLOCK] = check ? *px : -F32_INFINITY;\n // syntax for reduction in Triton is:\n // x[..., OPERATOR, ...]\n // ^\n // index\n // The operators currently supported are {min, max, +}\n float z [BLOCK] = x - x[max];\n // The exponential in Triton is fast but approximate\n // (i.e., like __expf in CUDA)\n float num [BLOCK] = exp(z);\n float denom = num[+];\n // The result of the reduction is now stored in y\n float y [BLOCK] = num / denom;\n // We write it back\n float* py [BLOCK] = Y + m*stride_ym + n;\n *?(check)py = y;\n }\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Torch Bindings\nWe need to make sure that BLOCK is the smallest power of two\ngreater than the number of rows N of the input matrix.\nDifferent values of BLOCK will result in different kernels\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport triton\n\n# Source code for the Triton kernel\n_src = \"\"\"\n__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n int m = get_program_id(0);\n int n [BLOCK] = 0 ... BLOCK;\n float* px [BLOCK] = X + m*stride_xm + n;\n bool check[BLOCK] = n < N;\n float x [BLOCK] = check ? *px : -F32_INFINITY;\n float z [BLOCK] = x - x[max];\n float num [BLOCK] = exp(z);\n float denom = num[+];\n float y [BLOCK] = num / denom;\n float* py [BLOCK] = Y + m*stride_ym + n;\n *?(check)py = y; \n}\n\"\"\"\n\n\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\n_kernels = dict()\n\n\ndef make_kernel(N, device):\n BLOCK = next_power_of_2(N)\n key = (BLOCK, device)\n if key not in _kernels:\n defines = {'BLOCK': BLOCK}\n _kernels[key] = triton.kernel(_src, device=device, defines=defines)\n return _kernels[key]\n\n\nclass _softmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x):\n # constraints of the op\n assert x.dtype == torch.float32\n y = torch.empty_like(x)\n # *create launch grid*:\n # here we just launch a grid of M programs\n M, N = y.shape\n grid = lambda opt: (M, )\n # *launch kernel*:\n kernel = make_kernel(N, y.device)\n kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)\n return y\n\n\nsoftmax = _softmax.apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Unit Test\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = torch.randn(1823, 781, device='cuda')\ny_tri = softmax(x)\ny_ref = torch.softmax(x, axis=1)\nprint(y_tri)\nprint(y_ref)\nprint(torch.allclose(y_tri, y_ref))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Seems to work!\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmarking\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n\nM = 4096\nNs = [128 * i for i in range(2, 50)]\ntri_ms = []\nref_ms = []\ndef_ms = []\nfor N in Ns:\n x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]\n ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]\n def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]\nplt.xlabel('N')\nplt.ylabel('Bandwidth (GB/s)')\nplt.plot(Ns, tri_ms, label='Triton')\nplt.plot(Ns, ref_ms, label='Torch')\nplt.plot(Ns, def_ms, label='Naive')\nplt.legend()\nplt.show()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
150
_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple, high-performance vector addition using Triton and learn about:
|
||||
|
||||
- The basic syntax of the Triton programming language
|
||||
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# --------------------------
|
||||
#
|
||||
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
|
||||
# on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)
|
||||
# programming model for more details).
|
||||
#
|
||||
# .. code-block:: C
|
||||
#
|
||||
# __global__ void add(float* z, float* x, float* y, int N){
|
||||
# // The `get_program_id(i)` returns the i-th coordinate
|
||||
# // of the program in the overaching SPMD context
|
||||
# // (a.k.a launch grid). This is what allows us to process
|
||||
# // different chunks of data in parallel.
|
||||
# // For those similar with CUDA, `get_program_id({0,1,2})`
|
||||
# // is similar to blockIdx.{x,y,z}
|
||||
# int pid = get_program_id(0);
|
||||
# // In Triton, arrays are first-class citizen. In other words,
|
||||
# // they are primitives data-types and are -- contrary to C and
|
||||
# // CUDA -- not implemented as pointers to contiguous chunks of
|
||||
# // memory.
|
||||
# // In the few lines below, we create an array of `BLOCK` pointers
|
||||
# // whose memory values are, e.g.:
|
||||
# // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
|
||||
# // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
|
||||
# int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
# float* pz [BLOCK] = z + offset;
|
||||
# float* px [BLOCK] = x + offset;
|
||||
# float* py [BLOCK] = y + offset;
|
||||
# // Simple element-wise control-flow for load/store operations can
|
||||
# // be achieved using the the ternary operator `cond ? val_true : val_false`
|
||||
# // or the conditional dereferencing operator `*?(cond)ptr
|
||||
# // Here, we make sure that we do not access memory out-of-bounds when we
|
||||
# // write-back `z`
|
||||
# bool check[BLOCK] = offset < N;
|
||||
# *?(check)pz = *?(check)px + *?(check)py;
|
||||
# }
|
||||
#
|
||||
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
|
||||
|
||||
# %%
|
||||
# Torch bindings
|
||||
# --------------------------
|
||||
# The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
|
||||
#
|
||||
# - :code:`source: string`: the source-code of the kernel you want to create
|
||||
# - :code:`device: torch.device`: the device you want to compile this code for
|
||||
# - :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# source-code for Triton compute kernel
|
||||
# here we just copy-paste the above code without the extensive comments.
|
||||
# you may prefer to store it in a .c file and load it from there instead.
|
||||
_src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// program id
|
||||
int pid = get_program_id(0);
|
||||
// create arrays of pointers
|
||||
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
float* pz[BLOCK] = z + offset;
|
||||
float* px[BLOCK] = x + offset;
|
||||
float* py[BLOCK] = y + offset;
|
||||
// bounds checking
|
||||
bool check[BLOCK] = offset < N;
|
||||
// write-back
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# This function returns a callable `triton.kernel` object created from the above source code.
|
||||
# For portability, we maintain a cache of kernels for different `torch.device`
|
||||
# We compile the kernel with -DBLOCK=1024
|
||||
def make_add_kernel(device):
|
||||
cache = make_add_kernel.cache
|
||||
if device not in cache:
|
||||
defines = {'BLOCK': 1024}
|
||||
cache[device] = triton.kernel(_src, device=device, defines=defines)
|
||||
return cache[device]
|
||||
|
||||
|
||||
make_add_kernel.cache = dict()
|
||||
|
||||
|
||||
# This is a standard torch custom autograd Function;
|
||||
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
|
||||
class _add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
# *allocate output*
|
||||
z = torch.empty_like(x)
|
||||
# *create launch grid*:
|
||||
# this is a function which takes compilation parameters `opt`
|
||||
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
|
||||
# triton.cdiv is a shortcut for ceil division:
|
||||
# triton.cdiv(a, b) = (a + b - 1) // b
|
||||
N = z.shape[0]
|
||||
grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
|
||||
# *launch kernel*:
|
||||
# pointer to the data of torch tensors can be retrieved with
|
||||
# the `.data_ptr()` method
|
||||
kernel = make_add_kernel(z.device)
|
||||
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
|
||||
return z
|
||||
|
||||
|
||||
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
|
||||
add = _add.apply
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# --------------------------
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432, device='cuda')
|
||||
y = torch.rand(98432, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
print(zb)
|
||||
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
|
||||
|
||||
# %%
|
||||
# Benchmarking
|
||||
# --------------------------
|
||||
# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does
|
||||
|
||||
warmup = 10
|
||||
rep = 200
|
||||
for N in [2**i for i in range(17, 26, 1)]:
|
||||
x = torch.rand(N, device='cuda')
|
||||
y = torch.rand(N, device='cuda')
|
||||
triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)
|
||||
torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)
|
||||
# print the performance of triton and torch as well as the achieved bandwidth
|
||||
print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')
|
BIN
_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip
Normal file
186
_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax layer that outperform's PyTorch implementation and learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- The syntax and usage of reduction operators in Triton.
|
||||
- The automatic vectorization capabilities of the Triton compiler.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Motivations
|
||||
# ------------
|
||||
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Compute the row-wise softmax of x
|
||||
def naive_softmax(x):
|
||||
# read MN elements ; write M elements
|
||||
x_max = torch.max(x, axis=1)[0]
|
||||
# read 2MN elements ; write MN elements
|
||||
z = x - x_max[:, None]
|
||||
# read MN elements ; write MN elements
|
||||
numerator = torch.exp(x)
|
||||
# read MN elements ; write M elements
|
||||
denominator = torch.sum(numerator, axis=1)
|
||||
# read 2MN elements ; write MN elements
|
||||
ret = numerator / denominator[:, None]
|
||||
# in total: read 7MN elements ; wrote 3MN + 2M elements
|
||||
return ret
|
||||
|
||||
|
||||
# %%
|
||||
# When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
# Instead, we want to write a custom "fused" pytorch operators that only reads X once and does all the necessary computations on-chip.
|
||||
# This would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of 5x.
|
||||
# In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.
|
||||
|
||||
# %%
|
||||
# Compute Kernel
|
||||
# ----------------------------
|
||||
# Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:
|
||||
#
|
||||
# .. code-block:: C
|
||||
#
|
||||
# __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
|
||||
# // row index
|
||||
# int m = get_program_id(0);
|
||||
# // column indices
|
||||
# int n [BLOCK] = 0 ... BLOCK;
|
||||
# // the memory address of all the elements
|
||||
# // that we want to load can be computed as follows
|
||||
# float* px [BLOCK] = X + m*stride_xm + n;
|
||||
# // because BLOCK has to be a power of two
|
||||
# // (per Triton-C specs), it is important
|
||||
# // to guard each memory operation with predicates
|
||||
# // or we will read out of bounds
|
||||
# bool check[BLOCK] = n < N;
|
||||
# float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
# // syntax for reduction in Triton is:
|
||||
# // x[..., OPERATOR, ...]
|
||||
# // ^
|
||||
# // index
|
||||
# // The operators currently supported are {min, max, +}
|
||||
# float z [BLOCK] = x - x[max];
|
||||
# // The exponential in Triton is fast but approximate
|
||||
# // (i.e., like __expf in CUDA)
|
||||
# float num [BLOCK] = exp(z);
|
||||
# float denom = num[+];
|
||||
# // The result of the reduction is now stored in y
|
||||
# float y [BLOCK] = num / denom;
|
||||
# // We write it back
|
||||
# float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
# *?(check)py = y;
|
||||
# }
|
||||
|
||||
# %%
|
||||
# Torch Bindings
|
||||
# ----------------------------
|
||||
# We need to make sure that BLOCK is the smallest power of two
|
||||
# greater than the number of rows N of the input matrix.
|
||||
# Different values of BLOCK will result in different kernels
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# Source code for the Triton kernel
|
||||
_src = """
|
||||
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
|
||||
int m = get_program_id(0);
|
||||
int n [BLOCK] = 0 ... BLOCK;
|
||||
float* px [BLOCK] = X + m*stride_xm + n;
|
||||
bool check[BLOCK] = n < N;
|
||||
float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
float z [BLOCK] = x - x[max];
|
||||
float num [BLOCK] = exp(z);
|
||||
float denom = num[+];
|
||||
float y [BLOCK] = num / denom;
|
||||
float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
*?(check)py = y;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
_kernels = dict()
|
||||
|
||||
|
||||
def make_kernel(N, device):
|
||||
BLOCK = next_power_of_2(N)
|
||||
key = (BLOCK, device)
|
||||
if key not in _kernels:
|
||||
defines = {'BLOCK': BLOCK}
|
||||
_kernels[key] = triton.kernel(_src, device=device, defines=defines)
|
||||
return _kernels[key]
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
y = torch.empty_like(x)
|
||||
# *create launch grid*:
|
||||
# here we just launch a grid of M programs
|
||||
M, N = y.shape
|
||||
grid = lambda opt: (M, )
|
||||
# *launch kernel*:
|
||||
kernel = make_kernel(N, y.device)
|
||||
kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
|
||||
return y
|
||||
|
||||
|
||||
softmax = _softmax.apply
|
||||
|
||||
# %%
|
||||
# Unit Test
|
||||
# ----------
|
||||
|
||||
x = torch.randn(1823, 781, device='cuda')
|
||||
y_tri = softmax(x)
|
||||
y_ref = torch.softmax(x, axis=1)
|
||||
print(y_tri)
|
||||
print(y_ref)
|
||||
print(torch.allclose(y_tri, y_ref))
|
||||
|
||||
# %%
|
||||
# Seems to work!
|
||||
|
||||
# %%
|
||||
# Benchmarking
|
||||
# ----------
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
M = 4096
|
||||
Ns = [128 * i for i in range(2, 50)]
|
||||
tri_ms = []
|
||||
ref_ms = []
|
||||
def_ms = []
|
||||
for N in Ns:
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]
|
||||
ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]
|
||||
def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]
|
||||
plt.xlabel('N')
|
||||
plt.ylabel('Bandwidth (GB/s)')
|
||||
plt.plot(Ns, tri_ms, label='Triton')
|
||||
plt.plot(Ns, ref_ms, label='Torch')
|
||||
plt.plot(Ns, def_ms, label='Naive')
|
||||
plt.legend()
|
||||
plt.show()
|
104
_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb
Normal file
@@ -0,0 +1,104 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Vector Addition\nIn this tutorial, you will write a simple, high-performance vector addition using Triton and learn about:\n\n- The basic syntax of the Triton programming language\n- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API\n- The best practices for validating and benchmarking custom ops against native reference implementations\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compute Kernel\n\nEach compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel\non different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)\nprogramming model for more details).\n\n .. code-block:: C\n\n __global__ void add(float* z, float* x, float* y, int N){\n // The `get_program_id(i)` returns the i-th coordinate\n // of the program in the overaching SPMD context\n // (a.k.a launch grid). This is what allows us to process\n // different chunks of data in parallel.\n // For those similar with CUDA, `get_program_id({0,1,2})`\n // is similar to blockIdx.{x,y,z}\n int pid = get_program_id(0);\n // In Triton, arrays are first-class citizen. In other words,\n // they are primitives data-types and are -- contrary to C and\n // CUDA -- not implemented as pointers to contiguous chunks of\n // memory.\n // In the few lines below, we create an array of `BLOCK` pointers\n // whose memory values are, e.g.:\n // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n float* pz [BLOCK] = z + offset;\n float* px [BLOCK] = x + offset;\n float* py [BLOCK] = y + offset;\n // Simple element-wise control-flow for load/store operations can\n // be achieved using the the ternary operator `cond ? val_true : val_false`\n // or the conditional dereferencing operator `*?(cond)ptr\n // Here, we make sure that we do not access memory out-of-bounds when we\n // write-back `z`\n bool check[BLOCK] = offset < N;\n *?(check)pz = *?(check)px + *?(check)py;\n }\n\nThe existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Torch bindings\nThe only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:\n\n- :code:`source: string`: the source-code of the kernel you want to create\n- :code:`device: torch.device`: the device you want to compile this code for\n- :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport triton\n\n# source-code for Triton compute kernel\n# here we just copy-paste the above code without the extensive comments.\n# you may prefer to store it in a .c file and load it from there instead.\n_src = \"\"\"\n__global__ void add(float* z, float* x, float* y, int N){\n // program id\n int pid = get_program_id(0);\n // create arrays of pointers\n int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n float* pz[BLOCK] = z + offset;\n float* px[BLOCK] = x + offset;\n float* py[BLOCK] = y + offset;\n // bounds checking\n bool check[BLOCK] = offset < N;\n // write-back\n *?(check)pz = *?(check)px + *?(check)py;\n}\n \"\"\"\n\n\n# This function returns a callable `triton.kernel` object created from the above source code.\n# For portability, we maintain a cache of kernels for different `torch.device`\n# We compile the kernel with -DBLOCK=1024\ndef make_add_kernel(device):\n cache = make_add_kernel.cache\n if device not in cache:\n defines = {'BLOCK': 1024}\n cache[device] = triton.kernel(_src, device=device, defines=defines)\n return cache[device]\n\n\nmake_add_kernel.cache = dict()\n\n\n# This is a standard torch custom autograd Function;\n# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`\nclass _add(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, y):\n # constraints of the op\n assert x.dtype == torch.float32\n # *allocate output*\n z = torch.empty_like(x)\n # *create launch grid*:\n # this is a function which takes compilation parameters `opt`\n # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n # triton.cdiv is a shortcut for ceil division:\n # triton.cdiv(a, b) = (a + b - 1) // b\n N = z.shape[0]\n grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n # *launch kernel*:\n # pointer to the data of torch tensors can be retrieved with\n # the `.data_ptr()` method\n kernel = make_add_kernel(z.device)\n kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n return z\n\n\n# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function\nadd = _add.apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Unit Test\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\nx = torch.rand(98432, device='cuda')\ny = torch.rand(98432, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Benchmarking\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"warmup = 10\nrep = 200\nfor N in [2**i for i in range(17, 26, 1)]:\n x = torch.rand(N, device='cuda')\n y = torch.rand(N, device='cuda')\n triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)\n torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)\n # print the performance of triton and torch as well as the achieved bandwidth\n print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
BIN
_images/sphx_glr_01-vector-add_thumb.png
Normal file
After Width: | Height: | Size: 26 KiB |
BIN
_images/sphx_glr_02-fused-softmax_001.png
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
_images/sphx_glr_02-fused-softmax_thumb.png
Normal file
After Width: | Height: | Size: 22 KiB |
Before Width: | Height: | Size: 18 KiB |
@@ -1,6 +1,21 @@
|
||||
==============
|
||||
Installation
|
||||
==============
|
||||
|
||||
--------------
|
||||
With Pip
|
||||
--------------
|
||||
|
||||
Triton can be installed directly from pip with the following command
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pip install triton
|
||||
|
||||
|
||||
--------------
|
||||
From Source
|
||||
==============
|
||||
--------------
|
||||
|
||||
+++++++++++++++
|
||||
Python Package
|
265
_sources/getting-started/tutorials/01-vector-add.rst.txt
Normal file
@@ -0,0 +1,265 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/01-vector-add.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_01-vector-add.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_01-vector-add.py:
|
||||
|
||||
|
||||
Vector Addition
|
||||
=================
|
||||
In this tutorial, you will write a simple, high-performance vector addition using Triton and learn about:
|
||||
|
||||
- The basic syntax of the Triton programming language
|
||||
- The best practices for creating PyTorch custom operators using the :code:`triton.kernel` Python API
|
||||
- The best practices for validating and benchmarking custom ops against native reference implementations
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-51
|
||||
|
||||
Compute Kernel
|
||||
--------------------------
|
||||
|
||||
Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
|
||||
on different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)
|
||||
programming model for more details).
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// The `get_program_id(i)` returns the i-th coordinate
|
||||
// of the program in the overaching SPMD context
|
||||
// (a.k.a launch grid). This is what allows us to process
|
||||
// different chunks of data in parallel.
|
||||
// For those similar with CUDA, `get_program_id({0,1,2})`
|
||||
// is similar to blockIdx.{x,y,z}
|
||||
int pid = get_program_id(0);
|
||||
// In Triton, arrays are first-class citizen. In other words,
|
||||
// they are primitives data-types and are -- contrary to C and
|
||||
// CUDA -- not implemented as pointers to contiguous chunks of
|
||||
// memory.
|
||||
// In the few lines below, we create an array of `BLOCK` pointers
|
||||
// whose memory values are, e.g.:
|
||||
// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
|
||||
// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
|
||||
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
float* pz [BLOCK] = z + offset;
|
||||
float* px [BLOCK] = x + offset;
|
||||
float* py [BLOCK] = y + offset;
|
||||
// Simple element-wise control-flow for load/store operations can
|
||||
// be achieved using the the ternary operator `cond ? val_true : val_false`
|
||||
// or the conditional dereferencing operator `*?(cond)ptr
|
||||
// Here, we make sure that we do not access memory out-of-bounds when we
|
||||
// write-back `z`
|
||||
bool check[BLOCK] = offset < N;
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
|
||||
The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 53-60
|
||||
|
||||
Torch bindings
|
||||
--------------------------
|
||||
The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
|
||||
|
||||
- :code:`source: string`: the source-code of the kernel you want to create
|
||||
- :code:`device: torch.device`: the device you want to compile this code for
|
||||
- :code:`defines: dict`: the set of macros that you want the pre-processor to `#define` for you
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 60-125
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# source-code for Triton compute kernel
|
||||
# here we just copy-paste the above code without the extensive comments.
|
||||
# you may prefer to store it in a .c file and load it from there instead.
|
||||
_src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N){
|
||||
// program id
|
||||
int pid = get_program_id(0);
|
||||
// create arrays of pointers
|
||||
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
|
||||
float* pz[BLOCK] = z + offset;
|
||||
float* px[BLOCK] = x + offset;
|
||||
float* py[BLOCK] = y + offset;
|
||||
// bounds checking
|
||||
bool check[BLOCK] = offset < N;
|
||||
// write-back
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# This function returns a callable `triton.kernel` object created from the above source code.
|
||||
# For portability, we maintain a cache of kernels for different `torch.device`
|
||||
# We compile the kernel with -DBLOCK=1024
|
||||
def make_add_kernel(device):
|
||||
cache = make_add_kernel.cache
|
||||
if device not in cache:
|
||||
defines = {'BLOCK': 1024}
|
||||
cache[device] = triton.kernel(_src, device=device, defines=defines)
|
||||
return cache[device]
|
||||
|
||||
|
||||
make_add_kernel.cache = dict()
|
||||
|
||||
|
||||
# This is a standard torch custom autograd Function;
|
||||
# The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`
|
||||
class _add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
# *allocate output*
|
||||
z = torch.empty_like(x)
|
||||
# *create launch grid*:
|
||||
# this is a function which takes compilation parameters `opt`
|
||||
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
|
||||
# triton.cdiv is a shortcut for ceil division:
|
||||
# triton.cdiv(a, b) = (a + b - 1) // b
|
||||
N = z.shape[0]
|
||||
grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
|
||||
# *launch kernel*:
|
||||
# pointer to the data of torch tensors can be retrieved with
|
||||
# the `.data_ptr()` method
|
||||
kernel = make_add_kernel(z.device)
|
||||
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
|
||||
return z
|
||||
|
||||
|
||||
# Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function
|
||||
add = _add.apply
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 126-128
|
||||
|
||||
Unit Test
|
||||
--------------------------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 128-137
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432, device='cuda')
|
||||
y = torch.rand(98432, device='cuda')
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(za)
|
||||
print(zb)
|
||||
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
The maximum difference between torch and triton is 0.0
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 138-141
|
||||
|
||||
Benchmarking
|
||||
--------------------------
|
||||
We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 141-150
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
warmup = 10
|
||||
rep = 200
|
||||
for N in [2**i for i in range(17, 26, 1)]:
|
||||
x = torch.rand(N, device='cuda')
|
||||
y = torch.rand(N, device='cuda')
|
||||
triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)
|
||||
torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)
|
||||
# print the performance of triton and torch as well as the achieved bandwidth
|
||||
print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
131072 0.022 0.006
|
||||
262144 0.021 0.005
|
||||
524288 0.022 0.017
|
||||
1048576 0.037 0.037
|
||||
2097152 0.074 0.073
|
||||
4194304 0.144 0.143
|
||||
8388608 0.289 0.285
|
||||
16777216 0.566 0.562
|
||||
33554432 1.131 1.121
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 3.225 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 01-vector-add.py <01-vector-add.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 01-vector-add.ipynb <01-vector-add.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
328
_sources/getting-started/tutorials/02-fused-softmax.rst.txt
Normal file
@@ -0,0 +1,328 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/02-fused-softmax.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
||||
|
||||
Fused Softmax
|
||||
=================
|
||||
In this tutorial, you will write a fused softmax layer that outperform's PyTorch implementation and learn about:
|
||||
|
||||
- The benefits of kernel fusion for bandwidth-bound operations.
|
||||
- The syntax and usage of reduction operators in Triton.
|
||||
- The automatic vectorization capabilities of the Triton compiler.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 12-16
|
||||
|
||||
Motivations
|
||||
------------
|
||||
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
|
||||
Let us consider instead the case of a simple (numerically stabilized) softmax operation:
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 16-36
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Compute the row-wise softmax of x
|
||||
def naive_softmax(x):
|
||||
# read MN elements ; write M elements
|
||||
x_max = torch.max(x, axis=1)[0]
|
||||
# read 2MN elements ; write MN elements
|
||||
z = x - x_max[:, None]
|
||||
# read MN elements ; write MN elements
|
||||
numerator = torch.exp(x)
|
||||
# read MN elements ; write M elements
|
||||
denominator = torch.sum(numerator, axis=1)
|
||||
# read 2MN elements ; write MN elements
|
||||
ret = numerator / denominator[:, None]
|
||||
# in total: read 7MN elements ; wrote 3MN + 2M elements
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 37-41
|
||||
|
||||
When implemented naively in pytorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
|
||||
Instead, we want to write a custom "fused" pytorch operators that only reads X once and does all the necessary computations on-chip.
|
||||
This would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of 5x.
|
||||
In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 43-79
|
||||
|
||||
Compute Kernel
|
||||
----------------------------
|
||||
Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
__global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
|
||||
// row index
|
||||
int m = get_program_id(0);
|
||||
// column indices
|
||||
int n [BLOCK] = 0 ... BLOCK;
|
||||
// the memory address of all the elements
|
||||
// that we want to load can be computed as follows
|
||||
float* px [BLOCK] = X + m*stride_xm + n;
|
||||
// because BLOCK has to be a power of two
|
||||
// (per Triton-C specs), it is important
|
||||
// to guard each memory operation with predicates
|
||||
// or we will read out of bounds
|
||||
bool check[BLOCK] = n < N;
|
||||
float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
// syntax for reduction in Triton is:
|
||||
// x[..., OPERATOR, ...]
|
||||
// ^
|
||||
// index
|
||||
// The operators currently supported are {min, max, +}
|
||||
float z [BLOCK] = x - x[max];
|
||||
// The exponential in Triton is fast but approximate
|
||||
// (i.e., like __expf in CUDA)
|
||||
float num [BLOCK] = exp(z);
|
||||
float denom = num[+];
|
||||
// The result of the reduction is now stored in y
|
||||
float y [BLOCK] = num / denom;
|
||||
// We write it back
|
||||
float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
*?(check)py = y;
|
||||
}
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 81-86
|
||||
|
||||
Torch Bindings
|
||||
----------------------------
|
||||
We need to make sure that BLOCK is the smallest power of two
|
||||
greater than the number of rows N of the input matrix.
|
||||
Different values of BLOCK will result in different kernels
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 86-149
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
# Source code for the Triton kernel
|
||||
_src = """
|
||||
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
|
||||
int m = get_program_id(0);
|
||||
int n [BLOCK] = 0 ... BLOCK;
|
||||
float* px [BLOCK] = X + m*stride_xm + n;
|
||||
bool check[BLOCK] = n < N;
|
||||
float x [BLOCK] = check ? *px : -F32_INFINITY;
|
||||
float z [BLOCK] = x - x[max];
|
||||
float num [BLOCK] = exp(z);
|
||||
float denom = num[+];
|
||||
float y [BLOCK] = num / denom;
|
||||
float* py [BLOCK] = Y + m*stride_ym + n;
|
||||
*?(check)py = y;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
_kernels = dict()
|
||||
|
||||
|
||||
def make_kernel(N, device):
|
||||
BLOCK = next_power_of_2(N)
|
||||
key = (BLOCK, device)
|
||||
if key not in _kernels:
|
||||
defines = {'BLOCK': BLOCK}
|
||||
_kernels[key] = triton.kernel(_src, device=device, defines=defines)
|
||||
return _kernels[key]
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
# constraints of the op
|
||||
assert x.dtype == torch.float32
|
||||
y = torch.empty_like(x)
|
||||
# *create launch grid*:
|
||||
# here we just launch a grid of M programs
|
||||
M, N = y.shape
|
||||
grid = lambda opt: (M, )
|
||||
# *launch kernel*:
|
||||
kernel = make_kernel(N, y.device)
|
||||
kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
|
||||
return y
|
||||
|
||||
|
||||
softmax = _softmax.apply
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 150-152
|
||||
|
||||
Unit Test
|
||||
----------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 152-160
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
x = torch.randn(1823, 781, device='cuda')
|
||||
y_tri = softmax(x)
|
||||
y_ref = torch.softmax(x, axis=1)
|
||||
print(y_tri)
|
||||
print(y_ref)
|
||||
print(torch.allclose(y_tri, y_ref))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([[2.0935e-03, 6.4551e-04, 9.8605e-05, ..., 3.3981e-04, 2.7386e-03,
|
||||
9.1986e-05],
|
||||
[7.0923e-04, 6.7521e-04, 5.1366e-04, ..., 9.8392e-04, 2.6547e-04,
|
||||
6.9062e-04],
|
||||
[1.4032e-04, 5.8826e-04, 1.1694e-03, ..., 6.6423e-04, 1.8178e-04,
|
||||
6.7049e-04],
|
||||
...,
|
||||
[1.1767e-03, 4.2703e-03, 6.0596e-04, ..., 9.5274e-04, 1.1681e-03,
|
||||
6.4924e-04],
|
||||
[1.0772e-04, 7.4854e-04, 3.1912e-03, ..., 2.4980e-04, 1.9012e-03,
|
||||
5.2567e-04],
|
||||
[2.8518e-03, 8.1899e-04, 7.7046e-04, ..., 1.3403e-03, 5.3167e-04,
|
||||
4.3268e-04]], device='cuda:0')
|
||||
tensor([[2.0935e-03, 6.4551e-04, 9.8605e-05, ..., 3.3981e-04, 2.7386e-03,
|
||||
9.1986e-05],
|
||||
[7.0923e-04, 6.7521e-04, 5.1366e-04, ..., 9.8392e-04, 2.6547e-04,
|
||||
6.9062e-04],
|
||||
[1.4032e-04, 5.8826e-04, 1.1694e-03, ..., 6.6423e-04, 1.8178e-04,
|
||||
6.7049e-04],
|
||||
...,
|
||||
[1.1767e-03, 4.2703e-03, 6.0596e-04, ..., 9.5274e-04, 1.1681e-03,
|
||||
6.4924e-04],
|
||||
[1.0772e-04, 7.4854e-04, 3.1912e-03, ..., 2.4980e-04, 1.9012e-03,
|
||||
5.2567e-04],
|
||||
[2.8518e-03, 8.1899e-04, 7.7046e-04, ..., 1.3403e-03, 5.3167e-04,
|
||||
4.3268e-04]], device='cuda:0')
|
||||
True
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 161-162
|
||||
|
||||
Seems to work!
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 164-166
|
||||
|
||||
Benchmarking
|
||||
----------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 166-186
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
M = 4096
|
||||
Ns = [128 * i for i in range(2, 50)]
|
||||
tri_ms = []
|
||||
ref_ms = []
|
||||
def_ms = []
|
||||
for N in Ns:
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]
|
||||
ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]
|
||||
def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]
|
||||
plt.xlabel('N')
|
||||
plt.ylabel('Bandwidth (GB/s)')
|
||||
plt.plot(Ns, tri_ms, label='Triton')
|
||||
plt.plot(Ns, ref_ms, label='Torch')
|
||||
plt.plot(Ns, def_ms, label='Naive')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
|
||||
:alt: 02 fused softmax
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 5.758 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
81
_sources/getting-started/tutorials/index.rst.txt
Normal file
@@ -0,0 +1,81 @@
|
||||
:orphan:
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials:
|
||||
|
||||
Tutorials
|
||||
==================
|
||||
|
||||
Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic syntax of the Triton programming language - The best practices for creating PyTorch...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_01-vector-add_thumb.png
|
||||
:alt: Vector Addition
|
||||
|
||||
:ref:`sphx_glr_getting-started_tutorials_01-vector-add.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/01-vector-add
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - The syntax and usage of reduc...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_02-fused-softmax_thumb.png
|
||||
:alt: Fused Softmax
|
||||
|
||||
:ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/02-fused-softmax
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-gallery
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download all examples in Python source code: tutorials_python.zip </getting-started/tutorials/tutorials_python.zip>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download all examples in Jupyter notebooks: tutorials_jupyter.zip </getting-started/tutorials/tutorials_jupyter.zip>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -0,0 +1,14 @@
|
||||
|
||||
:orphan:
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_sg_execution_times:
|
||||
|
||||
Computation times
|
||||
=================
|
||||
**00:08.983** total execution time for **getting-started_tutorials** files:
|
||||
|
||||
+-----------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 00:05.758 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:03.225 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------+-----------+--------+
|
@@ -1,21 +1,18 @@
|
||||
.. Triton documentation master file, created by
|
||||
sphinx-quickstart on Mon Feb 10 01:01:37 2020.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
Welcome to Triton's documentation!
|
||||
==================================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Installation Instructions
|
||||
Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
|
||||
|
||||
installation/packaged-binaries
|
||||
installation/from-source
|
||||
Getting Started
|
||||
---------------
|
||||
|
||||
- Follow the :doc:`installation instructions <getting-started/installation>` for your platform of choice.
|
||||
- Take a look at the :doc:`tutorials <getting-started/tutorials/index>` to learn how to write your first Triton program.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Tutorials
|
||||
:caption: Getting Started
|
||||
:hidden:
|
||||
|
||||
Vector Addition <tutorials/01-vector-add.ipynb>
|
||||
Fused Softmax <tutorials/02-fused-softmax.ipynb>
|
||||
getting-started/installation
|
||||
getting-started/tutorials/index
|
@@ -1,8 +0,0 @@
|
||||
Packaged Binaries
|
||||
=================
|
||||
|
||||
Triton can be installed directly from pip with the following command
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pip install triton
|
@@ -1,329 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "acute-possession",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Vector Addition"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "median-malaysia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
|
||||
"* The basic syntax of the Triton programming language\n",
|
||||
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
|
||||
"* The best practices for validating and benchmarking custom ops against native reference implementations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "identical-conditions",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Compute Kernel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "collectible-belle",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```c\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // The `get_program_id(i)` returns the i-th coordinate\n",
|
||||
" // of the program in the overaching SPMD context\n",
|
||||
" // (a.k.a launch grid). This is what allows us to process\n",
|
||||
" // different chunks of data in parallel.\n",
|
||||
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
|
||||
" // is similar to blockIdx.{x,y,z}\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // In Triton, arrays are first-class citizen. In other words,\n",
|
||||
" // they are primitives data-types and are -- contrary to C and\n",
|
||||
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
|
||||
" // memory.\n",
|
||||
" // In the few lines below, we create an array of `BLOCK` pointers\n",
|
||||
" // whose memory values are, e.g.:\n",
|
||||
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
|
||||
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz [BLOCK] = z + offset;\n",
|
||||
" float* px [BLOCK] = x + offset;\n",
|
||||
" float* py [BLOCK] = y + offset;\n",
|
||||
" // Simple element-wise control-flow for load/store operations can\n",
|
||||
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
|
||||
" // or the conditional dereferencing operator `*?(cond)ptr\n",
|
||||
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
|
||||
" // write-back `z`\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "forbidden-wednesday",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Torch bindings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "numerical-agency",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
|
||||
"\n",
|
||||
"To create a `triton.kernel`, you only need three things:\n",
|
||||
"* `source: string`: the source-code of the kernel you want to create\n",
|
||||
"* `device: torch.device`: the device you want to compile this code for\n",
|
||||
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
|
||||
"\n",
|
||||
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "sporting-keyboard",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import triton\n",
|
||||
"\n",
|
||||
"# source-code for Triton compute kernel\n",
|
||||
"# here we just copy-paste the above code without the extensive comments.\n",
|
||||
"# you may prefer to store it in a .c file and load it from there instead.\n",
|
||||
"_src = \"\"\"\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // program id\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // create arrays of pointers\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz[BLOCK] = z + offset;\n",
|
||||
" float* px[BLOCK] = x + offset;\n",
|
||||
" float* py[BLOCK] = y + offset;\n",
|
||||
" // bounds checking\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" // write-back\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
" \"\"\"\n",
|
||||
"# This function returns a callable `triton.kernel` object\n",
|
||||
"# created from the above source code.\n",
|
||||
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
|
||||
"# We compile the kernel with -DBLOCK=1024\n",
|
||||
"_kernels = dict()\n",
|
||||
"def make_add_kernel(device):\n",
|
||||
" if device not in _kernels:\n",
|
||||
" defines = {'BLOCK': 1024}\n",
|
||||
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
|
||||
" return _kernels[device]\n",
|
||||
"\n",
|
||||
"# This is a standard torch custom autograd Function\n",
|
||||
"# The only difference is that we can now use the above kernel\n",
|
||||
"# in the `forward` and `backward` functions.`\n",
|
||||
"class _add(torch.autograd.Function):\n",
|
||||
" \n",
|
||||
" @staticmethod\n",
|
||||
" def forward(ctx, x, y):\n",
|
||||
" # constraints of the op\n",
|
||||
" assert x.dtype == torch.float32\n",
|
||||
" # *allocate output*\n",
|
||||
" z = torch.empty_like(x)\n",
|
||||
" # *create launch grid*:\n",
|
||||
" # this is a function which takes compilation parameters `opt`\n",
|
||||
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
|
||||
" # triton.cdiv is a shortcut for ceil division:\n",
|
||||
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
|
||||
" N = z.shape[0]\n",
|
||||
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
|
||||
" # *launch kernel*:\n",
|
||||
" # pointer to the data of torch tensors can be retrieved with\n",
|
||||
" # the `.data_ptr()` method\n",
|
||||
" kernel = make_add_kernel(z.device)\n",
|
||||
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
|
||||
" return z\n",
|
||||
"# Just like we standard PyTorch ops\n",
|
||||
"# We use the `.apply` method to create a \n",
|
||||
"# callable object for our function\n",
|
||||
"add = _add.apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "separated-polyester",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exclusive-salvation",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Unit Test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "supported-ribbon",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"The maximum difference between torch and triton is 0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\n",
|
||||
"x = torch.rand(98432, device='cuda')\n",
|
||||
"y = torch.rand(98432, device='cuda')\n",
|
||||
"za = x + y\n",
|
||||
"zb = add(x, y)\n",
|
||||
"print(za)\n",
|
||||
"print(zb)\n",
|
||||
"print(f'The maximum difference between torch and triton is '\n",
|
||||
" f'{torch.max(torch.abs(za - zb))}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "otherwise-canadian",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Seems to work!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "polished-australia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Benchmark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "historic-glass",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "strange-luxembourg",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We now want to benchmark the performance of `add`\n",
|
||||
"# Against that of PyTorch for increasing vector sizes\n",
|
||||
"def do_bench(fn, warmup = 10, rep = 50):\n",
|
||||
" start_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" end_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" ret = fn()\n",
|
||||
" for i in range(warmup):\n",
|
||||
" fn()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" start_event.record()\n",
|
||||
" for i in range(rep):\n",
|
||||
" fn()\n",
|
||||
" end_event.record()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" time_ms = start_event.elapsed_time(end_event) / rep\n",
|
||||
" return time_ms"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "hairy-claim",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "pleasant-valley",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"131072 0.020 0.003\n",
|
||||
"262144 0.019 0.004\n",
|
||||
"524288 0.016 0.016\n",
|
||||
"1048576 0.033 0.033\n",
|
||||
"2097152 0.071 0.070\n",
|
||||
"4194304 0.142 0.144\n",
|
||||
"8388608 0.287 0.286\n",
|
||||
"16777216 0.572 0.568\n",
|
||||
"33554432 1.139 1.110\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for N in [2**i for i in range(17, 26, 1)]:\n",
|
||||
" x = torch.rand(N, device='cuda')\n",
|
||||
" y = torch.rand(N, device='cuda')\n",
|
||||
" triton_ms = do_bench(lambda: add(x, y))\n",
|
||||
" torch_ms = do_bench(lambda: x + y)\n",
|
||||
" # print the performance of triton and torch as well as the achieved bandwidth\n",
|
||||
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "juvenile-supplement",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Before Width: | Height: | Size: 673 B |
1
_static/binder_badge_logo.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="109" height="20"><linearGradient id="b" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="a"><rect width="109" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#a)"><path fill="#555" d="M0 0h64v20H0z"/><path fill="#579aca" d="M64 0h45v20H64z"/><path fill="url(#b)" d="M0 0h109v20H0z"/></g><g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="110"><image x="5" y="3" width="14" height="14" xlink:href=""/> <text x="415" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="370">launch</text><text x="415" y="140" transform="scale(.1)" textLength="370">launch</text><text x="855" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="350">binder</text><text x="855" y="140" transform="scale(.1)" textLength="350">binder</text></g> </svg>
|
After Width: | Height: | Size: 3.3 KiB |
BIN
_static/broken_example.png
Normal file
After Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 756 B |
Before Width: | Height: | Size: 829 B |
Before Width: | Height: | Size: 641 B |
Before Width: | Height: | Size: 222 B |
BIN
_static/down.png
Before Width: | Height: | Size: 202 B |
Before Width: | Height: | Size: 434 KiB |
6
_static/gallery-binder.css
Normal file
@@ -0,0 +1,6 @@
|
||||
/* CSS for binder integration */
|
||||
|
||||
div.binder-badge {
|
||||
margin: 1em auto;
|
||||
vertical-align: middle;
|
||||
}
|
36
_static/gallery-dataframe.css
Normal file
@@ -0,0 +1,36 @@
|
||||
/* Pandas dataframe css */
|
||||
/* Taken from: https://github.com/spatialaudio/nbsphinx/blob/fb3ba670fc1ba5f54d4c487573dbc1b4ecf7e9ff/src/nbsphinx.py#L587-L619 */
|
||||
|
||||
table.dataframe {
|
||||
border: none !important;
|
||||
border-collapse: collapse;
|
||||
border-spacing: 0;
|
||||
border-color: transparent;
|
||||
color: black;
|
||||
font-size: 12px;
|
||||
table-layout: fixed;
|
||||
}
|
||||
table.dataframe thead {
|
||||
border-bottom: 1px solid black;
|
||||
vertical-align: bottom;
|
||||
}
|
||||
table.dataframe tr,
|
||||
table.dataframe th,
|
||||
table.dataframe td {
|
||||
text-align: right;
|
||||
vertical-align: middle;
|
||||
padding: 0.5em 0.5em;
|
||||
line-height: normal;
|
||||
white-space: normal;
|
||||
max-width: none;
|
||||
border: none;
|
||||
}
|
||||
table.dataframe th {
|
||||
font-weight: bold;
|
||||
}
|
||||
table.dataframe tbody tr:nth-child(odd) {
|
||||
background: #f5f5f5;
|
||||
}
|
||||
table.dataframe tbody tr:hover {
|
||||
background: rgba(66, 165, 245, 0.2);
|
||||
}
|
209
_static/gallery-rendered-html.css
Normal file
@@ -0,0 +1,209 @@
|
||||
/* Adapted from notebook/static/style/style.min.css */
|
||||
|
||||
.rendered_html {
|
||||
color: #000;
|
||||
/* any extras will just be numbers: */
|
||||
}
|
||||
.rendered_html em {
|
||||
font-style: italic;
|
||||
}
|
||||
.rendered_html strong {
|
||||
font-weight: bold;
|
||||
}
|
||||
.rendered_html u {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.rendered_html :link {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.rendered_html :visited {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.rendered_html h1 {
|
||||
font-size: 185.7%;
|
||||
margin: 1.08em 0 0 0;
|
||||
font-weight: bold;
|
||||
line-height: 1.0;
|
||||
}
|
||||
.rendered_html h2 {
|
||||
font-size: 157.1%;
|
||||
margin: 1.27em 0 0 0;
|
||||
font-weight: bold;
|
||||
line-height: 1.0;
|
||||
}
|
||||
.rendered_html h3 {
|
||||
font-size: 128.6%;
|
||||
margin: 1.55em 0 0 0;
|
||||
font-weight: bold;
|
||||
line-height: 1.0;
|
||||
}
|
||||
.rendered_html h4 {
|
||||
font-size: 100%;
|
||||
margin: 2em 0 0 0;
|
||||
font-weight: bold;
|
||||
line-height: 1.0;
|
||||
}
|
||||
.rendered_html h5 {
|
||||
font-size: 100%;
|
||||
margin: 2em 0 0 0;
|
||||
font-weight: bold;
|
||||
line-height: 1.0;
|
||||
font-style: italic;
|
||||
}
|
||||
.rendered_html h6 {
|
||||
font-size: 100%;
|
||||
margin: 2em 0 0 0;
|
||||
font-weight: bold;
|
||||
line-height: 1.0;
|
||||
font-style: italic;
|
||||
}
|
||||
.rendered_html h1:first-child {
|
||||
margin-top: 0.538em;
|
||||
}
|
||||
.rendered_html h2:first-child {
|
||||
margin-top: 0.636em;
|
||||
}
|
||||
.rendered_html h3:first-child {
|
||||
margin-top: 0.777em;
|
||||
}
|
||||
.rendered_html h4:first-child {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html h5:first-child {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html h6:first-child {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html ul:not(.list-inline),
|
||||
.rendered_html ol:not(.list-inline) {
|
||||
padding-left: 2em;
|
||||
}
|
||||
.rendered_html ul {
|
||||
list-style: disc;
|
||||
}
|
||||
.rendered_html ul ul {
|
||||
list-style: square;
|
||||
margin-top: 0;
|
||||
}
|
||||
.rendered_html ul ul ul {
|
||||
list-style: circle;
|
||||
}
|
||||
.rendered_html ol {
|
||||
list-style: decimal;
|
||||
}
|
||||
.rendered_html ol ol {
|
||||
list-style: upper-alpha;
|
||||
margin-top: 0;
|
||||
}
|
||||
.rendered_html ol ol ol {
|
||||
list-style: lower-alpha;
|
||||
}
|
||||
.rendered_html ol ol ol ol {
|
||||
list-style: lower-roman;
|
||||
}
|
||||
.rendered_html ol ol ol ol ol {
|
||||
list-style: decimal;
|
||||
}
|
||||
.rendered_html * + ul {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html * + ol {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html hr {
|
||||
color: black;
|
||||
background-color: black;
|
||||
}
|
||||
.rendered_html pre {
|
||||
margin: 1em 2em;
|
||||
padding: 0px;
|
||||
background-color: #fff;
|
||||
}
|
||||
.rendered_html code {
|
||||
background-color: #eff0f1;
|
||||
}
|
||||
.rendered_html p code {
|
||||
padding: 1px 5px;
|
||||
}
|
||||
.rendered_html pre code {
|
||||
background-color: #fff;
|
||||
}
|
||||
.rendered_html pre,
|
||||
.rendered_html code {
|
||||
border: 0;
|
||||
color: #000;
|
||||
font-size: 100%;
|
||||
}
|
||||
.rendered_html blockquote {
|
||||
margin: 1em 2em;
|
||||
}
|
||||
.rendered_html table {
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
border: none;
|
||||
border-collapse: collapse;
|
||||
border-spacing: 0;
|
||||
color: black;
|
||||
font-size: 12px;
|
||||
table-layout: fixed;
|
||||
}
|
||||
.rendered_html thead {
|
||||
border-bottom: 1px solid black;
|
||||
vertical-align: bottom;
|
||||
}
|
||||
.rendered_html tr,
|
||||
.rendered_html th,
|
||||
.rendered_html td {
|
||||
text-align: right;
|
||||
vertical-align: middle;
|
||||
padding: 0.5em 0.5em;
|
||||
line-height: normal;
|
||||
white-space: normal;
|
||||
max-width: none;
|
||||
border: none;
|
||||
}
|
||||
.rendered_html th {
|
||||
font-weight: bold;
|
||||
}
|
||||
.rendered_html tbody tr:nth-child(odd) {
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.rendered_html tbody tr:hover {
|
||||
background: rgba(66, 165, 245, 0.2);
|
||||
}
|
||||
.rendered_html * + table {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html p {
|
||||
text-align: left;
|
||||
}
|
||||
.rendered_html * + p {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html img {
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
}
|
||||
.rendered_html * + img {
|
||||
margin-top: 1em;
|
||||
}
|
||||
.rendered_html img,
|
||||
.rendered_html svg {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
}
|
||||
.rendered_html img.unconfined,
|
||||
.rendered_html svg.unconfined {
|
||||
max-width: none;
|
||||
}
|
||||
.rendered_html .alert {
|
||||
margin-bottom: initial;
|
||||
}
|
||||
.rendered_html * + .alert {
|
||||
margin-top: 1em;
|
||||
}
|
||||
[dir="rtl"] .rendered_html p {
|
||||
text-align: right;
|
||||
}
|
204
_static/gallery.css
Normal file
@@ -0,0 +1,204 @@
|
||||
/*
|
||||
Sphinx-Gallery has compatible CSS to fix default sphinx themes
|
||||
Tested for Sphinx 1.3.1 for all themes: default, alabaster, sphinxdoc,
|
||||
scrolls, agogo, traditional, nature, haiku, pyramid
|
||||
Tested for Read the Docs theme 0.1.7 */
|
||||
.sphx-glr-thumbcontainer {
|
||||
background: #fff;
|
||||
border: solid #fff 1px;
|
||||
-moz-border-radius: 5px;
|
||||
-webkit-border-radius: 5px;
|
||||
border-radius: 5px;
|
||||
box-shadow: none;
|
||||
float: left;
|
||||
margin: 5px;
|
||||
min-height: 230px;
|
||||
padding-top: 5px;
|
||||
position: relative;
|
||||
}
|
||||
.sphx-glr-thumbcontainer:hover {
|
||||
border: solid #b4ddfc 1px;
|
||||
box-shadow: 0 0 15px rgba(142, 176, 202, 0.5);
|
||||
}
|
||||
.sphx-glr-thumbcontainer a.internal {
|
||||
bottom: 0;
|
||||
display: block;
|
||||
left: 0;
|
||||
padding: 150px 10px 0;
|
||||
position: absolute;
|
||||
right: 0;
|
||||
top: 0;
|
||||
}
|
||||
/* Next one is to avoid Sphinx traditional theme to cover all the
|
||||
thumbnail with its default link Background color */
|
||||
.sphx-glr-thumbcontainer a.internal:hover {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.sphx-glr-thumbcontainer p {
|
||||
margin: 0 0 .1em 0;
|
||||
}
|
||||
.sphx-glr-thumbcontainer .figure {
|
||||
margin: 10px;
|
||||
width: 160px;
|
||||
}
|
||||
.sphx-glr-thumbcontainer img {
|
||||
display: inline;
|
||||
max-height: 112px;
|
||||
max-width: 160px;
|
||||
}
|
||||
.sphx-glr-thumbcontainer[tooltip]:hover:after {
|
||||
background: rgba(0, 0, 0, 0.8);
|
||||
-webkit-border-radius: 5px;
|
||||
-moz-border-radius: 5px;
|
||||
border-radius: 5px;
|
||||
color: #fff;
|
||||
content: attr(tooltip);
|
||||
left: 95%;
|
||||
padding: 5px 15px;
|
||||
position: absolute;
|
||||
z-index: 98;
|
||||
width: 220px;
|
||||
bottom: 52%;
|
||||
}
|
||||
.sphx-glr-thumbcontainer[tooltip]:hover:before {
|
||||
border: solid;
|
||||
border-color: #333 transparent;
|
||||
border-width: 18px 0 0 20px;
|
||||
bottom: 58%;
|
||||
content: '';
|
||||
left: 85%;
|
||||
position: absolute;
|
||||
z-index: 99;
|
||||
}
|
||||
|
||||
.sphx-glr-script-out {
|
||||
color: #888;
|
||||
margin: 0;
|
||||
}
|
||||
p.sphx-glr-script-out {
|
||||
padding-top: 0.7em;
|
||||
}
|
||||
.sphx-glr-script-out .highlight {
|
||||
background-color: transparent;
|
||||
margin-left: 2.5em;
|
||||
margin-top: -2.1em;
|
||||
}
|
||||
.sphx-glr-script-out .highlight pre {
|
||||
background-color: #fafae2;
|
||||
border: 0;
|
||||
max-height: 30em;
|
||||
overflow: auto;
|
||||
padding-left: 1ex;
|
||||
margin: 0px;
|
||||
word-break: break-word;
|
||||
}
|
||||
.sphx-glr-script-out + p {
|
||||
margin-top: 1.8em;
|
||||
}
|
||||
blockquote.sphx-glr-script-out {
|
||||
margin-left: 0pt;
|
||||
}
|
||||
.sphx-glr-script-out.highlight-pytb .highlight pre {
|
||||
color: #000;
|
||||
background-color: #ffe4e4;
|
||||
border: 1px solid #f66;
|
||||
margin-top: 10px;
|
||||
padding: 7px;
|
||||
}
|
||||
|
||||
div.sphx-glr-footer {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
div.sphx-glr-download {
|
||||
margin: 1em auto;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
div.sphx-glr-download a {
|
||||
background-color: #ffc;
|
||||
background-image: linear-gradient(to bottom, #FFC, #d5d57e);
|
||||
border-radius: 4px;
|
||||
border: 1px solid #c2c22d;
|
||||
color: #000;
|
||||
display: inline-block;
|
||||
font-weight: bold;
|
||||
padding: 1ex;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
div.sphx-glr-download code.download {
|
||||
display: inline-block;
|
||||
white-space: normal;
|
||||
word-break: normal;
|
||||
overflow-wrap: break-word;
|
||||
/* border and background are given by the enclosing 'a' */
|
||||
border: none;
|
||||
background: none;
|
||||
}
|
||||
|
||||
div.sphx-glr-download a:hover {
|
||||
box-shadow: inset 0 1px 0 rgba(255,255,255,.1), 0 1px 5px rgba(0,0,0,.25);
|
||||
text-decoration: none;
|
||||
background-image: none;
|
||||
background-color: #d5d57e;
|
||||
}
|
||||
|
||||
.sphx-glr-example-title > :target::before {
|
||||
display: block;
|
||||
content: "";
|
||||
margin-top: -50px;
|
||||
height: 50px;
|
||||
visibility: hidden;
|
||||
}
|
||||
|
||||
ul.sphx-glr-horizontal {
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
}
|
||||
ul.sphx-glr-horizontal li {
|
||||
display: inline;
|
||||
}
|
||||
ul.sphx-glr-horizontal img {
|
||||
height: auto !important;
|
||||
}
|
||||
|
||||
.sphx-glr-single-img {
|
||||
margin: auto;
|
||||
display: block;
|
||||
max-width: 100%;
|
||||
}
|
||||
|
||||
.sphx-glr-multi-img {
|
||||
max-width: 42%;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
div.sphx-glr-animation {
|
||||
margin: auto;
|
||||
display: block;
|
||||
max-width: 100%;
|
||||
}
|
||||
div.sphx-glr-animation .animation{
|
||||
display: block;
|
||||
}
|
||||
|
||||
p.sphx-glr-signature a.reference.external {
|
||||
-moz-border-radius: 5px;
|
||||
-webkit-border-radius: 5px;
|
||||
border-radius: 5px;
|
||||
padding: 3px;
|
||||
font-size: 75%;
|
||||
text-align: right;
|
||||
margin-left: auto;
|
||||
display: table;
|
||||
}
|
||||
|
||||
.sphx-glr-clear{
|
||||
clear: both;
|
||||
}
|
||||
|
||||
a.sphx-glr-backref-instance {
|
||||
text-decoration: none;
|
||||
}
|
1
_static/js/modernizr.min.js
vendored
BIN
_static/no_image.png
Normal file
After Width: | Height: | Size: 4.2 KiB |
Before Width: | Height: | Size: 214 B |
BIN
_static/up.png
Before Width: | Height: | Size: 203 B |
@@ -1,808 +0,0 @@
|
||||
/*
|
||||
* websupport.js
|
||||
* ~~~~~~~~~~~~~
|
||||
*
|
||||
* sphinx.websupport utilities for all documentation.
|
||||
*
|
||||
* :copyright: Copyright 2007-2018 by the Sphinx team, see AUTHORS.
|
||||
* :license: BSD, see LICENSE for details.
|
||||
*
|
||||
*/
|
||||
|
||||
(function($) {
|
||||
$.fn.autogrow = function() {
|
||||
return this.each(function() {
|
||||
var textarea = this;
|
||||
|
||||
$.fn.autogrow.resize(textarea);
|
||||
|
||||
$(textarea)
|
||||
.focus(function() {
|
||||
textarea.interval = setInterval(function() {
|
||||
$.fn.autogrow.resize(textarea);
|
||||
}, 500);
|
||||
})
|
||||
.blur(function() {
|
||||
clearInterval(textarea.interval);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
$.fn.autogrow.resize = function(textarea) {
|
||||
var lineHeight = parseInt($(textarea).css('line-height'), 10);
|
||||
var lines = textarea.value.split('\n');
|
||||
var columns = textarea.cols;
|
||||
var lineCount = 0;
|
||||
$.each(lines, function() {
|
||||
lineCount += Math.ceil(this.length / columns) || 1;
|
||||
});
|
||||
var height = lineHeight * (lineCount + 1);
|
||||
$(textarea).css('height', height);
|
||||
};
|
||||
})(jQuery);
|
||||
|
||||
(function($) {
|
||||
var comp, by;
|
||||
|
||||
function init() {
|
||||
initEvents();
|
||||
initComparator();
|
||||
}
|
||||
|
||||
function initEvents() {
|
||||
$(document).on("click", 'a.comment-close', function(event) {
|
||||
event.preventDefault();
|
||||
hide($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.vote', function(event) {
|
||||
event.preventDefault();
|
||||
handleVote($(this));
|
||||
});
|
||||
$(document).on("click", 'a.reply', function(event) {
|
||||
event.preventDefault();
|
||||
openReply($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.close-reply', function(event) {
|
||||
event.preventDefault();
|
||||
closeReply($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.sort-option', function(event) {
|
||||
event.preventDefault();
|
||||
handleReSort($(this));
|
||||
});
|
||||
$(document).on("click", 'a.show-proposal', function(event) {
|
||||
event.preventDefault();
|
||||
showProposal($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.hide-proposal', function(event) {
|
||||
event.preventDefault();
|
||||
hideProposal($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.show-propose-change', function(event) {
|
||||
event.preventDefault();
|
||||
showProposeChange($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.hide-propose-change', function(event) {
|
||||
event.preventDefault();
|
||||
hideProposeChange($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.accept-comment', function(event) {
|
||||
event.preventDefault();
|
||||
acceptComment($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.delete-comment', function(event) {
|
||||
event.preventDefault();
|
||||
deleteComment($(this).attr('id').substring(2));
|
||||
});
|
||||
$(document).on("click", 'a.comment-markup', function(event) {
|
||||
event.preventDefault();
|
||||
toggleCommentMarkupBox($(this).attr('id').substring(2));
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Set comp, which is a comparator function used for sorting and
|
||||
* inserting comments into the list.
|
||||
*/
|
||||
function setComparator() {
|
||||
// If the first three letters are "asc", sort in ascending order
|
||||
// and remove the prefix.
|
||||
if (by.substring(0,3) == 'asc') {
|
||||
var i = by.substring(3);
|
||||
comp = function(a, b) { return a[i] - b[i]; };
|
||||
} else {
|
||||
// Otherwise sort in descending order.
|
||||
comp = function(a, b) { return b[by] - a[by]; };
|
||||
}
|
||||
|
||||
// Reset link styles and format the selected sort option.
|
||||
$('a.sel').attr('href', '#').removeClass('sel');
|
||||
$('a.by' + by).removeAttr('href').addClass('sel');
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a comp function. If the user has preferences stored in
|
||||
* the sortBy cookie, use those, otherwise use the default.
|
||||
*/
|
||||
function initComparator() {
|
||||
by = 'rating'; // Default to sort by rating.
|
||||
// If the sortBy cookie is set, use that instead.
|
||||
if (document.cookie.length > 0) {
|
||||
var start = document.cookie.indexOf('sortBy=');
|
||||
if (start != -1) {
|
||||
start = start + 7;
|
||||
var end = document.cookie.indexOf(";", start);
|
||||
if (end == -1) {
|
||||
end = document.cookie.length;
|
||||
by = unescape(document.cookie.substring(start, end));
|
||||
}
|
||||
}
|
||||
}
|
||||
setComparator();
|
||||
}
|
||||
|
||||
/**
|
||||
* Show a comment div.
|
||||
*/
|
||||
function show(id) {
|
||||
$('#ao' + id).hide();
|
||||
$('#ah' + id).show();
|
||||
var context = $.extend({id: id}, opts);
|
||||
var popup = $(renderTemplate(popupTemplate, context)).hide();
|
||||
popup.find('textarea[name="proposal"]').hide();
|
||||
popup.find('a.by' + by).addClass('sel');
|
||||
var form = popup.find('#cf' + id);
|
||||
form.submit(function(event) {
|
||||
event.preventDefault();
|
||||
addComment(form);
|
||||
});
|
||||
$('#s' + id).after(popup);
|
||||
popup.slideDown('fast', function() {
|
||||
getComments(id);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Hide a comment div.
|
||||
*/
|
||||
function hide(id) {
|
||||
$('#ah' + id).hide();
|
||||
$('#ao' + id).show();
|
||||
var div = $('#sc' + id);
|
||||
div.slideUp('fast', function() {
|
||||
div.remove();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform an ajax request to get comments for a node
|
||||
* and insert the comments into the comments tree.
|
||||
*/
|
||||
function getComments(id) {
|
||||
$.ajax({
|
||||
type: 'GET',
|
||||
url: opts.getCommentsURL,
|
||||
data: {node: id},
|
||||
success: function(data, textStatus, request) {
|
||||
var ul = $('#cl' + id);
|
||||
var speed = 100;
|
||||
$('#cf' + id)
|
||||
.find('textarea[name="proposal"]')
|
||||
.data('source', data.source);
|
||||
|
||||
if (data.comments.length === 0) {
|
||||
ul.html('<li>No comments yet.</li>');
|
||||
ul.data('empty', true);
|
||||
} else {
|
||||
// If there are comments, sort them and put them in the list.
|
||||
var comments = sortComments(data.comments);
|
||||
speed = data.comments.length * 100;
|
||||
appendComments(comments, ul);
|
||||
ul.data('empty', false);
|
||||
}
|
||||
$('#cn' + id).slideUp(speed + 200);
|
||||
ul.slideDown(speed);
|
||||
},
|
||||
error: function(request, textStatus, error) {
|
||||
showError('Oops, there was a problem retrieving the comments.');
|
||||
},
|
||||
dataType: 'json'
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a comment via ajax and insert the comment into the comment tree.
|
||||
*/
|
||||
function addComment(form) {
|
||||
var node_id = form.find('input[name="node"]').val();
|
||||
var parent_id = form.find('input[name="parent"]').val();
|
||||
var text = form.find('textarea[name="comment"]').val();
|
||||
var proposal = form.find('textarea[name="proposal"]').val();
|
||||
|
||||
if (text == '') {
|
||||
showError('Please enter a comment.');
|
||||
return;
|
||||
}
|
||||
|
||||
// Disable the form that is being submitted.
|
||||
form.find('textarea,input').attr('disabled', 'disabled');
|
||||
|
||||
// Send the comment to the server.
|
||||
$.ajax({
|
||||
type: "POST",
|
||||
url: opts.addCommentURL,
|
||||
dataType: 'json',
|
||||
data: {
|
||||
node: node_id,
|
||||
parent: parent_id,
|
||||
text: text,
|
||||
proposal: proposal
|
||||
},
|
||||
success: function(data, textStatus, error) {
|
||||
// Reset the form.
|
||||
if (node_id) {
|
||||
hideProposeChange(node_id);
|
||||
}
|
||||
form.find('textarea')
|
||||
.val('')
|
||||
.add(form.find('input'))
|
||||
.removeAttr('disabled');
|
||||
var ul = $('#cl' + (node_id || parent_id));
|
||||
if (ul.data('empty')) {
|
||||
$(ul).empty();
|
||||
ul.data('empty', false);
|
||||
}
|
||||
insertComment(data.comment);
|
||||
var ao = $('#ao' + node_id);
|
||||
ao.find('img').attr({'src': opts.commentBrightImage});
|
||||
if (node_id) {
|
||||
// if this was a "root" comment, remove the commenting box
|
||||
// (the user can get it back by reopening the comment popup)
|
||||
$('#ca' + node_id).slideUp();
|
||||
}
|
||||
},
|
||||
error: function(request, textStatus, error) {
|
||||
form.find('textarea,input').removeAttr('disabled');
|
||||
showError('Oops, there was a problem adding the comment.');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively append comments to the main comment list and children
|
||||
* lists, creating the comment tree.
|
||||
*/
|
||||
function appendComments(comments, ul) {
|
||||
$.each(comments, function() {
|
||||
var div = createCommentDiv(this);
|
||||
ul.append($(document.createElement('li')).html(div));
|
||||
appendComments(this.children, div.find('ul.comment-children'));
|
||||
// To avoid stagnating data, don't store the comments children in data.
|
||||
this.children = null;
|
||||
div.data('comment', this);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* After adding a new comment, it must be inserted in the correct
|
||||
* location in the comment tree.
|
||||
*/
|
||||
function insertComment(comment) {
|
||||
var div = createCommentDiv(comment);
|
||||
|
||||
// To avoid stagnating data, don't store the comments children in data.
|
||||
comment.children = null;
|
||||
div.data('comment', comment);
|
||||
|
||||
var ul = $('#cl' + (comment.node || comment.parent));
|
||||
var siblings = getChildren(ul);
|
||||
|
||||
var li = $(document.createElement('li'));
|
||||
li.hide();
|
||||
|
||||
// Determine where in the parents children list to insert this comment.
|
||||
for(i=0; i < siblings.length; i++) {
|
||||
if (comp(comment, siblings[i]) <= 0) {
|
||||
$('#cd' + siblings[i].id)
|
||||
.parent()
|
||||
.before(li.html(div));
|
||||
li.slideDown('fast');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, this comment rates lower than all the others,
|
||||
// or it is the only comment in the list.
|
||||
ul.append(li.html(div));
|
||||
li.slideDown('fast');
|
||||
}
|
||||
|
||||
function acceptComment(id) {
|
||||
$.ajax({
|
||||
type: 'POST',
|
||||
url: opts.acceptCommentURL,
|
||||
data: {id: id},
|
||||
success: function(data, textStatus, request) {
|
||||
$('#cm' + id).fadeOut('fast');
|
||||
$('#cd' + id).removeClass('moderate');
|
||||
},
|
||||
error: function(request, textStatus, error) {
|
||||
showError('Oops, there was a problem accepting the comment.');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function deleteComment(id) {
|
||||
$.ajax({
|
||||
type: 'POST',
|
||||
url: opts.deleteCommentURL,
|
||||
data: {id: id},
|
||||
success: function(data, textStatus, request) {
|
||||
var div = $('#cd' + id);
|
||||
if (data == 'delete') {
|
||||
// Moderator mode: remove the comment and all children immediately
|
||||
div.slideUp('fast', function() {
|
||||
div.remove();
|
||||
});
|
||||
return;
|
||||
}
|
||||
// User mode: only mark the comment as deleted
|
||||
div
|
||||
.find('span.user-id:first')
|
||||
.text('[deleted]').end()
|
||||
.find('div.comment-text:first')
|
||||
.text('[deleted]').end()
|
||||
.find('#cm' + id + ', #dc' + id + ', #ac' + id + ', #rc' + id +
|
||||
', #sp' + id + ', #hp' + id + ', #cr' + id + ', #rl' + id)
|
||||
.remove();
|
||||
var comment = div.data('comment');
|
||||
comment.username = '[deleted]';
|
||||
comment.text = '[deleted]';
|
||||
div.data('comment', comment);
|
||||
},
|
||||
error: function(request, textStatus, error) {
|
||||
showError('Oops, there was a problem deleting the comment.');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function showProposal(id) {
|
||||
$('#sp' + id).hide();
|
||||
$('#hp' + id).show();
|
||||
$('#pr' + id).slideDown('fast');
|
||||
}
|
||||
|
||||
function hideProposal(id) {
|
||||
$('#hp' + id).hide();
|
||||
$('#sp' + id).show();
|
||||
$('#pr' + id).slideUp('fast');
|
||||
}
|
||||
|
||||
function showProposeChange(id) {
|
||||
$('#pc' + id).hide();
|
||||
$('#hc' + id).show();
|
||||
var textarea = $('#pt' + id);
|
||||
textarea.val(textarea.data('source'));
|
||||
$.fn.autogrow.resize(textarea[0]);
|
||||
textarea.slideDown('fast');
|
||||
}
|
||||
|
||||
function hideProposeChange(id) {
|
||||
$('#hc' + id).hide();
|
||||
$('#pc' + id).show();
|
||||
var textarea = $('#pt' + id);
|
||||
textarea.val('').removeAttr('disabled');
|
||||
textarea.slideUp('fast');
|
||||
}
|
||||
|
||||
function toggleCommentMarkupBox(id) {
|
||||
$('#mb' + id).toggle();
|
||||
}
|
||||
|
||||
/** Handle when the user clicks on a sort by link. */
|
||||
function handleReSort(link) {
|
||||
var classes = link.attr('class').split(/\s+/);
|
||||
for (var i=0; i<classes.length; i++) {
|
||||
if (classes[i] != 'sort-option') {
|
||||
by = classes[i].substring(2);
|
||||
}
|
||||
}
|
||||
setComparator();
|
||||
// Save/update the sortBy cookie.
|
||||
var expiration = new Date();
|
||||
expiration.setDate(expiration.getDate() + 365);
|
||||
document.cookie= 'sortBy=' + escape(by) +
|
||||
';expires=' + expiration.toUTCString();
|
||||
$('ul.comment-ul').each(function(index, ul) {
|
||||
var comments = getChildren($(ul), true);
|
||||
comments = sortComments(comments);
|
||||
appendComments(comments, $(ul).empty());
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Function to process a vote when a user clicks an arrow.
|
||||
*/
|
||||
function handleVote(link) {
|
||||
if (!opts.voting) {
|
||||
showError("You'll need to login to vote.");
|
||||
return;
|
||||
}
|
||||
|
||||
var id = link.attr('id');
|
||||
if (!id) {
|
||||
// Didn't click on one of the voting arrows.
|
||||
return;
|
||||
}
|
||||
// If it is an unvote, the new vote value is 0,
|
||||
// Otherwise it's 1 for an upvote, or -1 for a downvote.
|
||||
var value = 0;
|
||||
if (id.charAt(1) != 'u') {
|
||||
value = id.charAt(0) == 'u' ? 1 : -1;
|
||||
}
|
||||
// The data to be sent to the server.
|
||||
var d = {
|
||||
comment_id: id.substring(2),
|
||||
value: value
|
||||
};
|
||||
|
||||
// Swap the vote and unvote links.
|
||||
link.hide();
|
||||
$('#' + id.charAt(0) + (id.charAt(1) == 'u' ? 'v' : 'u') + d.comment_id)
|
||||
.show();
|
||||
|
||||
// The div the comment is displayed in.
|
||||
var div = $('div#cd' + d.comment_id);
|
||||
var data = div.data('comment');
|
||||
|
||||
// If this is not an unvote, and the other vote arrow has
|
||||
// already been pressed, unpress it.
|
||||
if ((d.value !== 0) && (data.vote === d.value * -1)) {
|
||||
$('#' + (d.value == 1 ? 'd' : 'u') + 'u' + d.comment_id).hide();
|
||||
$('#' + (d.value == 1 ? 'd' : 'u') + 'v' + d.comment_id).show();
|
||||
}
|
||||
|
||||
// Update the comments rating in the local data.
|
||||
data.rating += (data.vote === 0) ? d.value : (d.value - data.vote);
|
||||
data.vote = d.value;
|
||||
div.data('comment', data);
|
||||
|
||||
// Change the rating text.
|
||||
div.find('.rating:first')
|
||||
.text(data.rating + ' point' + (data.rating == 1 ? '' : 's'));
|
||||
|
||||
// Send the vote information to the server.
|
||||
$.ajax({
|
||||
type: "POST",
|
||||
url: opts.processVoteURL,
|
||||
data: d,
|
||||
error: function(request, textStatus, error) {
|
||||
showError('Oops, there was a problem casting that vote.');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Open a reply form used to reply to an existing comment.
|
||||
*/
|
||||
function openReply(id) {
|
||||
// Swap out the reply link for the hide link
|
||||
$('#rl' + id).hide();
|
||||
$('#cr' + id).show();
|
||||
|
||||
// Add the reply li to the children ul.
|
||||
var div = $(renderTemplate(replyTemplate, {id: id})).hide();
|
||||
$('#cl' + id)
|
||||
.prepend(div)
|
||||
// Setup the submit handler for the reply form.
|
||||
.find('#rf' + id)
|
||||
.submit(function(event) {
|
||||
event.preventDefault();
|
||||
addComment($('#rf' + id));
|
||||
closeReply(id);
|
||||
})
|
||||
.find('input[type=button]')
|
||||
.click(function() {
|
||||
closeReply(id);
|
||||
});
|
||||
div.slideDown('fast', function() {
|
||||
$('#rf' + id).find('textarea').focus();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the reply form opened with openReply.
|
||||
*/
|
||||
function closeReply(id) {
|
||||
// Remove the reply div from the DOM.
|
||||
$('#rd' + id).slideUp('fast', function() {
|
||||
$(this).remove();
|
||||
});
|
||||
|
||||
// Swap out the hide link for the reply link
|
||||
$('#cr' + id).hide();
|
||||
$('#rl' + id).show();
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively sort a tree of comments using the comp comparator.
|
||||
*/
|
||||
function sortComments(comments) {
|
||||
comments.sort(comp);
|
||||
$.each(comments, function() {
|
||||
this.children = sortComments(this.children);
|
||||
});
|
||||
return comments;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the children comments from a ul. If recursive is true,
|
||||
* recursively include childrens' children.
|
||||
*/
|
||||
function getChildren(ul, recursive) {
|
||||
var children = [];
|
||||
ul.children().children("[id^='cd']")
|
||||
.each(function() {
|
||||
var comment = $(this).data('comment');
|
||||
if (recursive)
|
||||
comment.children = getChildren($(this).find('#cl' + comment.id), true);
|
||||
children.push(comment);
|
||||
});
|
||||
return children;
|
||||
}
|
||||
|
||||
/** Create a div to display a comment in. */
|
||||
function createCommentDiv(comment) {
|
||||
if (!comment.displayed && !opts.moderator) {
|
||||
return $('<div class="moderate">Thank you! Your comment will show up '
|
||||
+ 'once it is has been approved by a moderator.</div>');
|
||||
}
|
||||
// Prettify the comment rating.
|
||||
comment.pretty_rating = comment.rating + ' point' +
|
||||
(comment.rating == 1 ? '' : 's');
|
||||
// Make a class (for displaying not yet moderated comments differently)
|
||||
comment.css_class = comment.displayed ? '' : ' moderate';
|
||||
// Create a div for this comment.
|
||||
var context = $.extend({}, opts, comment);
|
||||
var div = $(renderTemplate(commentTemplate, context));
|
||||
|
||||
// If the user has voted on this comment, highlight the correct arrow.
|
||||
if (comment.vote) {
|
||||
var direction = (comment.vote == 1) ? 'u' : 'd';
|
||||
div.find('#' + direction + 'v' + comment.id).hide();
|
||||
div.find('#' + direction + 'u' + comment.id).show();
|
||||
}
|
||||
|
||||
if (opts.moderator || comment.text != '[deleted]') {
|
||||
div.find('a.reply').show();
|
||||
if (comment.proposal_diff)
|
||||
div.find('#sp' + comment.id).show();
|
||||
if (opts.moderator && !comment.displayed)
|
||||
div.find('#cm' + comment.id).show();
|
||||
if (opts.moderator || (opts.username == comment.username))
|
||||
div.find('#dc' + comment.id).show();
|
||||
}
|
||||
return div;
|
||||
}
|
||||
|
||||
/**
|
||||
* A simple template renderer. Placeholders such as <%id%> are replaced
|
||||
* by context['id'] with items being escaped. Placeholders such as <#id#>
|
||||
* are not escaped.
|
||||
*/
|
||||
function renderTemplate(template, context) {
|
||||
var esc = $(document.createElement('div'));
|
||||
|
||||
function handle(ph, escape) {
|
||||
var cur = context;
|
||||
$.each(ph.split('.'), function() {
|
||||
cur = cur[this];
|
||||
});
|
||||
return escape ? esc.text(cur || "").html() : cur;
|
||||
}
|
||||
|
||||
return template.replace(/<([%#])([\w\.]*)\1>/g, function() {
|
||||
return handle(arguments[2], arguments[1] == '%' ? true : false);
|
||||
});
|
||||
}
|
||||
|
||||
/** Flash an error message briefly. */
|
||||
function showError(message) {
|
||||
$(document.createElement('div')).attr({'class': 'popup-error'})
|
||||
.append($(document.createElement('div'))
|
||||
.attr({'class': 'error-message'}).text(message))
|
||||
.appendTo('body')
|
||||
.fadeIn("slow")
|
||||
.delay(2000)
|
||||
.fadeOut("slow");
|
||||
}
|
||||
|
||||
/** Add a link the user uses to open the comments popup. */
|
||||
$.fn.comment = function() {
|
||||
return this.each(function() {
|
||||
var id = $(this).attr('id').substring(1);
|
||||
var count = COMMENT_METADATA[id];
|
||||
var title = count + ' comment' + (count == 1 ? '' : 's');
|
||||
var image = count > 0 ? opts.commentBrightImage : opts.commentImage;
|
||||
var addcls = count == 0 ? ' nocomment' : '';
|
||||
$(this)
|
||||
.append(
|
||||
$(document.createElement('a')).attr({
|
||||
href: '#',
|
||||
'class': 'sphinx-comment-open' + addcls,
|
||||
id: 'ao' + id
|
||||
})
|
||||
.append($(document.createElement('img')).attr({
|
||||
src: image,
|
||||
alt: 'comment',
|
||||
title: title
|
||||
}))
|
||||
.click(function(event) {
|
||||
event.preventDefault();
|
||||
show($(this).attr('id').substring(2));
|
||||
})
|
||||
)
|
||||
.append(
|
||||
$(document.createElement('a')).attr({
|
||||
href: '#',
|
||||
'class': 'sphinx-comment-close hidden',
|
||||
id: 'ah' + id
|
||||
})
|
||||
.append($(document.createElement('img')).attr({
|
||||
src: opts.closeCommentImage,
|
||||
alt: 'close',
|
||||
title: 'close'
|
||||
}))
|
||||
.click(function(event) {
|
||||
event.preventDefault();
|
||||
hide($(this).attr('id').substring(2));
|
||||
})
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
var opts = {
|
||||
processVoteURL: '/_process_vote',
|
||||
addCommentURL: '/_add_comment',
|
||||
getCommentsURL: '/_get_comments',
|
||||
acceptCommentURL: '/_accept_comment',
|
||||
deleteCommentURL: '/_delete_comment',
|
||||
commentImage: '/static/_static/comment.png',
|
||||
closeCommentImage: '/static/_static/comment-close.png',
|
||||
loadingImage: '/static/_static/ajax-loader.gif',
|
||||
commentBrightImage: '/static/_static/comment-bright.png',
|
||||
upArrow: '/static/_static/up.png',
|
||||
downArrow: '/static/_static/down.png',
|
||||
upArrowPressed: '/static/_static/up-pressed.png',
|
||||
downArrowPressed: '/static/_static/down-pressed.png',
|
||||
voting: false,
|
||||
moderator: false
|
||||
};
|
||||
|
||||
if (typeof COMMENT_OPTIONS != "undefined") {
|
||||
opts = jQuery.extend(opts, COMMENT_OPTIONS);
|
||||
}
|
||||
|
||||
var popupTemplate = '\
|
||||
<div class="sphinx-comments" id="sc<%id%>">\
|
||||
<p class="sort-options">\
|
||||
Sort by:\
|
||||
<a href="#" class="sort-option byrating">best rated</a>\
|
||||
<a href="#" class="sort-option byascage">newest</a>\
|
||||
<a href="#" class="sort-option byage">oldest</a>\
|
||||
</p>\
|
||||
<div class="comment-header">Comments</div>\
|
||||
<div class="comment-loading" id="cn<%id%>">\
|
||||
loading comments... <img src="<%loadingImage%>" alt="" /></div>\
|
||||
<ul id="cl<%id%>" class="comment-ul"></ul>\
|
||||
<div id="ca<%id%>">\
|
||||
<p class="add-a-comment">Add a comment\
|
||||
(<a href="#" class="comment-markup" id="ab<%id%>">markup</a>):</p>\
|
||||
<div class="comment-markup-box" id="mb<%id%>">\
|
||||
reStructured text markup: <i>*emph*</i>, <b>**strong**</b>, \
|
||||
<code>``code``</code>, \
|
||||
code blocks: <code>::</code> and an indented block after blank line</div>\
|
||||
<form method="post" id="cf<%id%>" class="comment-form" action="">\
|
||||
<textarea name="comment" cols="80"></textarea>\
|
||||
<p class="propose-button">\
|
||||
<a href="#" id="pc<%id%>" class="show-propose-change">\
|
||||
Propose a change ▹\
|
||||
</a>\
|
||||
<a href="#" id="hc<%id%>" class="hide-propose-change">\
|
||||
Propose a change ▿\
|
||||
</a>\
|
||||
</p>\
|
||||
<textarea name="proposal" id="pt<%id%>" cols="80"\
|
||||
spellcheck="false"></textarea>\
|
||||
<input type="submit" value="Add comment" />\
|
||||
<input type="hidden" name="node" value="<%id%>" />\
|
||||
<input type="hidden" name="parent" value="" />\
|
||||
</form>\
|
||||
</div>\
|
||||
</div>';
|
||||
|
||||
var commentTemplate = '\
|
||||
<div id="cd<%id%>" class="sphinx-comment<%css_class%>">\
|
||||
<div class="vote">\
|
||||
<div class="arrow">\
|
||||
<a href="#" id="uv<%id%>" class="vote" title="vote up">\
|
||||
<img src="<%upArrow%>" />\
|
||||
</a>\
|
||||
<a href="#" id="uu<%id%>" class="un vote" title="vote up">\
|
||||
<img src="<%upArrowPressed%>" />\
|
||||
</a>\
|
||||
</div>\
|
||||
<div class="arrow">\
|
||||
<a href="#" id="dv<%id%>" class="vote" title="vote down">\
|
||||
<img src="<%downArrow%>" id="da<%id%>" />\
|
||||
</a>\
|
||||
<a href="#" id="du<%id%>" class="un vote" title="vote down">\
|
||||
<img src="<%downArrowPressed%>" />\
|
||||
</a>\
|
||||
</div>\
|
||||
</div>\
|
||||
<div class="comment-content">\
|
||||
<p class="tagline comment">\
|
||||
<span class="user-id"><%username%></span>\
|
||||
<span class="rating"><%pretty_rating%></span>\
|
||||
<span class="delta"><%time.delta%></span>\
|
||||
</p>\
|
||||
<div class="comment-text comment"><#text#></div>\
|
||||
<p class="comment-opts comment">\
|
||||
<a href="#" class="reply hidden" id="rl<%id%>">reply ▹</a>\
|
||||
<a href="#" class="close-reply" id="cr<%id%>">reply ▿</a>\
|
||||
<a href="#" id="sp<%id%>" class="show-proposal">proposal ▹</a>\
|
||||
<a href="#" id="hp<%id%>" class="hide-proposal">proposal ▿</a>\
|
||||
<a href="#" id="dc<%id%>" class="delete-comment hidden">delete</a>\
|
||||
<span id="cm<%id%>" class="moderation hidden">\
|
||||
<a href="#" id="ac<%id%>" class="accept-comment">accept</a>\
|
||||
</span>\
|
||||
</p>\
|
||||
<pre class="proposal" id="pr<%id%>">\
|
||||
<#proposal_diff#>\
|
||||
</pre>\
|
||||
<ul class="comment-children" id="cl<%id%>"></ul>\
|
||||
</div>\
|
||||
<div class="clearleft"></div>\
|
||||
</div>\
|
||||
</div>';
|
||||
|
||||
var replyTemplate = '\
|
||||
<li>\
|
||||
<div class="reply-div" id="rd<%id%>">\
|
||||
<form id="rf<%id%>">\
|
||||
<textarea name="comment" cols="80"></textarea>\
|
||||
<input type="submit" value="Add reply" />\
|
||||
<input type="button" value="Cancel" />\
|
||||
<input type="hidden" name="parent" value="<%id%>" />\
|
||||
<input type="hidden" name="node" value="" />\
|
||||
</form>\
|
||||
</div>\
|
||||
</li>';
|
||||
|
||||
$(document).ready(function() {
|
||||
init();
|
||||
});
|
||||
})(jQuery);
|
||||
|
||||
$(document).ready(function() {
|
||||
// add comment anchors for all paragraphs that are commentable
|
||||
$('.sphinx-has-comment').comment();
|
||||
|
||||
// highlight search words in search results
|
||||
$("div.context").each(function() {
|
||||
var params = $.getQueryParameters();
|
||||
var terms = (params.q) ? params.q[0].split(/\s+/) : [];
|
||||
var result = $(this);
|
||||
$.each(terms, function() {
|
||||
result.highlightText(this.toLowerCase(), 'highlighted');
|
||||
});
|
||||
});
|
||||
|
||||
// directly open comment window if requested
|
||||
var anchor = document.location.hash;
|
||||
if (anchor.substring(0, 9) == '#comment-') {
|
||||
$('#ao' + anchor.substring(9)).click();
|
||||
document.location.hash = '#s' + anchor.substring(9);
|
||||
}
|
||||
});
|
@@ -13,6 +13,10 @@
|
||||
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
@@ -32,9 +36,7 @@
|
||||
<script src="_static/jquery.js"></script>
|
||||
<script src="_static/underscore.js"></script>
|
||||
<script src="_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="_static/js/theme.js"></script>
|
||||
|
||||
@@ -85,15 +87,10 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/02-fused-softmax.html">Fused Softmax</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
@@ -7,12 +7,16 @@
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>From Source — Triton documentation</title>
|
||||
<title>Installation — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
@@ -32,17 +36,15 @@
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="Vector Addition" href="../tutorials/01-vector-add.html" />
|
||||
<link rel="prev" title="Packaged Binaries" href="packaged-binaries.html" />
|
||||
<link rel="next" title="Tutorials" href="tutorials/index.html" />
|
||||
<link rel="prev" title="Welcome to Triton’s documentation!" href="../index.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
@@ -87,19 +89,18 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">From Source</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#python-package">Python Package</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#c-package">C++ Package</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Installation</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#with-pip">With Pip</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#from-source">From Source</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#python-package">Python Package</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#c-package">C++ Package</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../tutorials/01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../tutorials/02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -148,13 +149,13 @@
|
||||
|
||||
<li><a href="../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>From Source</li>
|
||||
<li>Installation</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../_sources/installation/from-source.rst.txt" rel="nofollow"> View page source</a>
|
||||
<a href="../_sources/getting-started/installation.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
@@ -167,34 +168,19 @@
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="installation">
|
||||
<h1>Installation<a class="headerlink" href="#installation" title="Permalink to this headline">¶</a></h1>
|
||||
<div class="section" id="with-pip">
|
||||
<h2>With Pip<a class="headerlink" href="#with-pip" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Triton can be installed directly from pip with the following command</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pip</span> <span class="n">install</span> <span class="n">triton</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="from-source">
|
||||
<h1>From Source<a class="headerlink" href="#from-source" title="Permalink to this headline">¶</a></h1>
|
||||
<h2>From Source<a class="headerlink" href="#from-source" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="section" id="python-package">
|
||||
<h2>Python Package<a class="headerlink" href="#python-package" title="Permalink to this headline">¶</a></h2>
|
||||
<h3>Python Package<a class="headerlink" href="#python-package" title="Permalink to this headline">¶</a></h3>
|
||||
<p>You can install the Python package from source by running the following commands:</p>
|
||||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>sudo apt-get install llvm-10-dev
|
||||
git clone https://github.com/ptillet/triton.git<span class="p">;</span>
|
||||
@@ -213,7 +199,7 @@ python -m run --with-plots --result-dir /tmp/triton-bench
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="c-package">
|
||||
<h2>C++ Package<a class="headerlink" href="#c-package" title="Permalink to this headline">¶</a></h2>
|
||||
<h3>C++ Package<a class="headerlink" href="#c-package" title="Permalink to this headline">¶</a></h3>
|
||||
<p>Those not interested in Python integration may want to use the internals of Triton (i.e, runtime, parser, codegen, driver, intermediate representation) directly. This can be done by running the following commands:</p>
|
||||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>sudo apt-get install llvm-10-dev
|
||||
git clone https://github.com/ptillet/triton.git<span class="p">;</span>
|
||||
@@ -229,6 +215,7 @@ make -j8<span class="p">;</span>
|
||||
</div>
|
||||
<p>Note that while direct usage of the C++ API is not officially supported, a usage tutorial can be found <a class="reference external" href="https://github.com/ptillet/triton/blob/master/tutorials/01-matmul.cc">here</a></p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -237,8 +224,8 @@ make -j8<span class="p">;</span>
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="../tutorials/01-vector-add.html" class="btn btn-neutral float-right" title="Vector Addition" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="packaged-binaries.html" class="btn btn-neutral float-left" title="Packaged Binaries" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
<a href="tutorials/index.html" class="btn btn-neutral float-right" title="Tutorials" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="../index.html" class="btn btn-neutral float-left" title="Welcome to Triton’s documentation!" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
407
getting-started/tutorials/01-vector-add.html
Normal file
@@ -0,0 +1,407 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Vector Addition — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
|
||||
<link rel="prev" title="Tutorials" href="index.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch bindings</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#benchmarking">Benchmarking</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li><a href="index.html">Tutorials</a> »</li>
|
||||
|
||||
<li>Vector Addition</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="sphx-glr-download-link-note admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">here</span></a>
|
||||
to download the full example code</p>
|
||||
</div>
|
||||
<div class="sphx-glr-example-title section" id="vector-addition">
|
||||
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, you will write a simple, high-performance vector addition using Triton and learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The basic syntax of the Triton programming language</p></li>
|
||||
<li><p>The best practices for creating PyTorch custom operators using the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</p></li>
|
||||
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations</p></li>
|
||||
</ul>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Each compute kernel is declared using the <code class="code docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel
|
||||
on different chunks of data (See the <a class="reference external" href="(https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a>)
|
||||
programming model for more details).</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
|
||||
<span class="c1">// of the program in the overaching SPMD context</span>
|
||||
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
|
||||
<span class="c1">// different chunks of data in parallel.</span>
|
||||
<span class="c1">// For those similar with CUDA, `get_program_id({0,1,2})`</span>
|
||||
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
|
||||
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// In Triton, arrays are first-class citizen. In other words,</span>
|
||||
<span class="c1">// they are primitives data-types and are -- contrary to C and</span>
|
||||
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
|
||||
<span class="c1">// memory.</span>
|
||||
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
|
||||
<span class="c1">// whose memory values are, e.g.:</span>
|
||||
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
|
||||
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
|
||||
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
|
||||
<span class="c1">// be achieved using the the ternary operator `cond ? val_true : val_false`</span>
|
||||
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr</span>
|
||||
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
|
||||
<span class="c1">// write-back `z`</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL’2019 Triton paper</a>.</p>
|
||||
</div>
|
||||
<div class="section" id="torch-bindings">
|
||||
<h2>Torch bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The only thing that matters when it comes to Triton and Torch is the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="code docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects. To create a <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
|
||||
<ul class="simple">
|
||||
<li><p><code class="code docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li>
|
||||
<li><p><code class="code docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for</p></li>
|
||||
<li><p><code class="code docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <cite>#define</cite> for you</p></li>
|
||||
</ul>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># source-code for Triton compute kernel</span>
|
||||
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
|
||||
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
|
||||
<span class="s2"> // program id</span>
|
||||
<span class="s2"> int pid = get_program_id(0);</span>
|
||||
<span class="s2"> // create arrays of pointers</span>
|
||||
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
|
||||
<span class="s2"> float* px[BLOCK] = x + offset;</span>
|
||||
<span class="s2"> float* py[BLOCK] = y + offset;</span>
|
||||
<span class="s2"> // bounds checking</span>
|
||||
<span class="s2"> bool check[BLOCK] = offset < N;</span>
|
||||
<span class="s2"> // write-back</span>
|
||||
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2"> """</span>
|
||||
|
||||
|
||||
<span class="c1"># This function returns a callable `triton.kernel` object created from the above source code.</span>
|
||||
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
|
||||
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
|
||||
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span>
|
||||
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
|
||||
<span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
|
||||
|
||||
|
||||
<span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
|
||||
|
||||
<span class="c1"># This is a standard torch custom autograd Function;</span>
|
||||
<span class="c1"># The only difference is that we can now use the above kernel in the `forward` and `backward` functions.`</span>
|
||||
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="c1"># *allocate output*</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
|
||||
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
|
||||
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
|
||||
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
|
||||
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
|
||||
<span class="c1"># the `.data_ptr()` method</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">z</span>
|
||||
|
||||
|
||||
<span class="c1"># Just like we standard PyTorch ops We use the :code:`.apply` method to create a callable object for our function</span>
|
||||
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="unit-test">
|
||||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
The maximum difference between torch and triton is 0.0
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="benchmarking">
|
||||
<h2>Benchmarking<a class="headerlink" href="#benchmarking" title="Permalink to this headline">¶</a></h2>
|
||||
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">warmup</span> <span class="o">=</span> <span class="mi">10</span>
|
||||
<span class="n">rep</span> <span class="o">=</span> <span class="mi">200</span>
|
||||
<span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="mi">26</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">triton_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">),</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
|
||||
<span class="n">torch_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="n">rep</span><span class="p">)</span>
|
||||
<span class="c1"># print the performance of triton and torch as well as the achieved bandwidth</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">N</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">triton_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">torch_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>131072 0.022 0.006
|
||||
262144 0.021 0.005
|
||||
524288 0.022 0.017
|
||||
1048576 0.037 0.037
|
||||
2097152 0.074 0.073
|
||||
4194304 0.144 0.143
|
||||
8388608 0.289 0.285
|
||||
16777216 0.566 0.562
|
||||
33554432 1.131 1.121
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 3.225 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||
</div>
|
||||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">01-vector-add.ipynb</span></code></a></p>
|
||||
</div>
|
||||
</div>
|
||||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="index.html" class="btn btn-neutral float-left" title="Tutorials" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
450
getting-started/tutorials/02-fused-softmax.html
Normal file
@@ -0,0 +1,450 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Fused Softmax — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="prev" title="Vector Addition" href="01-vector-add.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
|
||||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
|
||||
<li class="toctree-l3"><a class="reference internal" href="#benchmarking">Benchmarking</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li><a href="index.html">Tutorials</a> »</li>
|
||||
|
||||
<li>Fused Softmax</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../_sources/getting-started/tutorials/02-fused-softmax.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="sphx-glr-download-link-note admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">here</span></a>
|
||||
to download the full example code</p>
|
||||
</div>
|
||||
<div class="sphx-glr-example-title section" id="fused-softmax">
|
||||
<span id="sphx-glr-getting-started-tutorials-02-fused-softmax-py"></span><h1>Fused Softmax<a class="headerlink" href="#fused-softmax" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, you will write a fused softmax layer that outperform’s PyTorch implementation and learn about:</p>
|
||||
<ul class="simple">
|
||||
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
|
||||
<li><p>The syntax and usage of reduction operators in Triton.</p></li>
|
||||
<li><p>The automatic vectorization capabilities of the Triton compiler.</p></li>
|
||||
</ul>
|
||||
<div class="section" id="motivations">
|
||||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice.
|
||||
Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
|
||||
<span class="c1"># Compute the row-wise softmax of x</span>
|
||||
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">x_max</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># read MN elements ; write MN elements</span>
|
||||
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">denominator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span>
|
||||
<span class="k">return</span> <span class="n">ret</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>When implemented naively in pytorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
|
||||
Instead, we want to write a custom “fused” pytorch operators that only reads X once and does all the necessary computations on-chip.
|
||||
This would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of 5x.
|
||||
In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.</p>
|
||||
</div>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
<blockquote>
|
||||
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">softmax</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">Y</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">X</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// row index</span>
|
||||
<span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// column indices</span>
|
||||
<span class="kt">int</span> <span class="n">n</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="c1">// the memory address of all the elements</span>
|
||||
<span class="c1">// that we want to load can be computed as follows</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="c1">// because BLOCK has to be a power of two</span>
|
||||
<span class="c1">// (per Triton-C specs), it is important</span>
|
||||
<span class="c1">// to guard each memory operation with predicates</span>
|
||||
<span class="c1">// or we will read out of bounds</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="kt">float</span> <span class="n">x</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">check</span> <span class="o">?</span> <span class="o">*</span><span class="nl">px</span> <span class="p">:</span> <span class="o">-</span><span class="n">F32_INFINITY</span><span class="p">;</span>
|
||||
<span class="c1">// syntax for reduction in Triton is:</span>
|
||||
<span class="c1">// x[..., OPERATOR, ...]</span>
|
||||
<span class="c1">// ^</span>
|
||||
<span class="c1">// index</span>
|
||||
<span class="c1">// The operators currently supported are {min, max, +}</span>
|
||||
<span class="kt">float</span> <span class="n">z</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="p">[</span><span class="n">max</span><span class="p">];</span>
|
||||
<span class="c1">// The exponential in Triton is fast but approximate</span>
|
||||
<span class="c1">// (i.e., like __expf in CUDA)</span>
|
||||
<span class="kt">float</span> <span class="n">num</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
|
||||
<span class="kt">float</span> <span class="n">denom</span> <span class="o">=</span> <span class="n">num</span><span class="p">[</span><span class="o">+</span><span class="p">];</span>
|
||||
<span class="c1">// The result of the reduction is now stored in y</span>
|
||||
<span class="kt">float</span> <span class="n">y</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">;</span>
|
||||
<span class="c1">// We write it back</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span> <span class="o">=</span> <span class="n">y</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
</div>
|
||||
<div class="section" id="torch-bindings">
|
||||
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>We need to make sure that BLOCK is the smallest power of two
|
||||
greater than the number of rows N of the input matrix.
|
||||
Different values of BLOCK will result in different kernels</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># Source code for the Triton kernel</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){</span>
|
||||
<span class="s2"> int m = get_program_id(0);</span>
|
||||
<span class="s2"> int n [BLOCK] = 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* px [BLOCK] = X + m*stride_xm + n;</span>
|
||||
<span class="s2"> bool check[BLOCK] = n < N;</span>
|
||||
<span class="s2"> float x [BLOCK] = check ? *px : -F32_INFINITY;</span>
|
||||
<span class="s2"> float z [BLOCK] = x - x[max];</span>
|
||||
<span class="s2"> float num [BLOCK] = exp(z);</span>
|
||||
<span class="s2"> float denom = num[+];</span>
|
||||
<span class="s2"> float y [BLOCK] = num / denom;</span>
|
||||
<span class="s2"> float* py [BLOCK] = Y + m*stride_ym + n;</span>
|
||||
<span class="s2"> *?(check)py = y;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">2</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">4</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">8</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">16</span>
|
||||
<span class="n">n</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="k">return</span> <span class="n">n</span>
|
||||
|
||||
|
||||
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
|
||||
<span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
|
||||
|
||||
|
||||
<span class="k">class</span> <span class="nc">_softmax</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># here we just launch a grid of M programs</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
|
||||
|
||||
<span class="n">softmax</span> <span class="o">=</span> <span class="n">_softmax</span><span class="o">.</span><span class="n">apply</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="unit-test">
|
||||
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">y_tri</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">y_ref</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[2.0935e-03, 6.4551e-04, 9.8605e-05, ..., 3.3981e-04, 2.7386e-03,
|
||||
9.1986e-05],
|
||||
[7.0923e-04, 6.7521e-04, 5.1366e-04, ..., 9.8392e-04, 2.6547e-04,
|
||||
6.9062e-04],
|
||||
[1.4032e-04, 5.8826e-04, 1.1694e-03, ..., 6.6423e-04, 1.8178e-04,
|
||||
6.7049e-04],
|
||||
...,
|
||||
[1.1767e-03, 4.2703e-03, 6.0596e-04, ..., 9.5274e-04, 1.1681e-03,
|
||||
6.4924e-04],
|
||||
[1.0772e-04, 7.4854e-04, 3.1912e-03, ..., 2.4980e-04, 1.9012e-03,
|
||||
5.2567e-04],
|
||||
[2.8518e-03, 8.1899e-04, 7.7046e-04, ..., 1.3403e-03, 5.3167e-04,
|
||||
4.3268e-04]], device='cuda:0')
|
||||
tensor([[2.0935e-03, 6.4551e-04, 9.8605e-05, ..., 3.3981e-04, 2.7386e-03,
|
||||
9.1986e-05],
|
||||
[7.0923e-04, 6.7521e-04, 5.1366e-04, ..., 9.8392e-04, 2.6547e-04,
|
||||
6.9062e-04],
|
||||
[1.4032e-04, 5.8826e-04, 1.1694e-03, ..., 6.6423e-04, 1.8178e-04,
|
||||
6.7049e-04],
|
||||
...,
|
||||
[1.1767e-03, 4.2703e-03, 6.0596e-04, ..., 9.5274e-04, 1.1681e-03,
|
||||
6.4924e-04],
|
||||
[1.0772e-04, 7.4854e-04, 3.1912e-03, ..., 2.4980e-04, 1.9012e-03,
|
||||
5.2567e-04],
|
||||
[2.8518e-03, 8.1899e-04, 7.7046e-04, ..., 1.3403e-03, 5.3167e-04,
|
||||
4.3268e-04]], device='cuda:0')
|
||||
True
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Seems to work!</p>
|
||||
</div>
|
||||
<div class="section" id="benchmarking">
|
||||
<h2>Benchmarking<a class="headerlink" href="#benchmarking" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
|
||||
|
||||
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
|
||||
<span class="n">Ns</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">50</span><span class="p">)]</span>
|
||||
<span class="n">tri_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">ref_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">def_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="n">Ns</span><span class="p">:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">nelement</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="mf">1e-9</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||||
<span class="n">tri_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
|
||||
<span class="n">ref_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)))]</span>
|
||||
<span class="n">def_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">'N'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">'Bandwidth (GB/s)'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">tri_ms</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">'Triton'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">ref_ms</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">'Torch'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">def_ms</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">'Naive'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<img alt="02 fused softmax" class="sphx-glr-single-img" src="../../_images/sphx_glr_02-fused-softmax_001.png" />
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 5.758 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
</div>
|
||||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/034d953b6214fedce6ea03803c712b89/02-fused-softmax.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">02-fused-softmax.ipynb</span></code></a></p>
|
||||
</div>
|
||||
</div>
|
||||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
241
getting-started/tutorials/index.html
Normal file
@@ -0,0 +1,241 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Tutorials — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="next" title="Vector Addition" href="01-vector-add.html" />
|
||||
<link rel="prev" title="Installation" href="../installation.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorials</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Tutorials</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../_sources/getting-started/tutorials/index.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="section" id="tutorials">
|
||||
<span id="sphx-glr-getting-started-tutorials"></span><h1>Tutorials<a class="headerlink" href="#tutorials" title="Permalink to this headline">¶</a></h1>
|
||||
<p>Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one.</p>
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The basic syntax of the Triton programming language - The best practices for creating PyTorch..."><div class="figure align-default" id="id1">
|
||||
<img alt="Vector Addition" src="../../_images/sphx_glr_01-vector-add_thumb.png" />
|
||||
<p class="caption"><span class="caption-text"><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a></span><a class="headerlink" href="#id1" title="Permalink to this image">¶</a></p>
|
||||
</div>
|
||||
</div><div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="- The benefits of kernel fusion for bandwidth-bound operations. - The syntax and usage of reduc..."><div class="figure align-default" id="id2">
|
||||
<img alt="Fused Softmax" src="../../_images/sphx_glr_02-fused-softmax_thumb.png" />
|
||||
<p class="caption"><span class="caption-text"><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a></span><a class="headerlink" href="#id2" title="Permalink to this image">¶</a></p>
|
||||
</div>
|
||||
</div><div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="sphx-glr-clear"></div><div class="sphx-glr-footer class sphx-glr-footer-gallery docutils container">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">tutorials_python.zip</span></code></a></p>
|
||||
</div>
|
||||
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">all</span> <span class="pre">examples</span> <span class="pre">in</span> <span class="pre">Jupyter</span> <span class="pre">notebooks:</span> <span class="pre">tutorials_jupyter.zip</span></code></a></p>
|
||||
</div>
|
||||
</div>
|
||||
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="01-vector-add.html" class="btn btn-neutral float-right" title="Vector Addition" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="../installation.html" class="btn btn-neutral float-left" title="Installation" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
227
getting-started/tutorials/sg_execution_times.html
Normal file
@@ -0,0 +1,227 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Computation times — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
|
||||
<script src="../../_static/jquery.js"></script>
|
||||
<script src="../../_static/underscore.js"></script>
|
||||
<script src="../../_static/doctools.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
|
||||
<script type="text/javascript" src="../../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Computation times</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../../_sources/getting-started/tutorials/sg_execution_times.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
<div class="section" id="computation-times">
|
||||
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
||||
<p><strong>00:08.983</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<table class="docutils align-default">
|
||||
<colgroup>
|
||||
<col style="width: 82%" />
|
||||
<col style="width: 10%" />
|
||||
<col style="width: 7%" />
|
||||
</colgroup>
|
||||
<tbody>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>00:05.758</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>00:03.225</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
65
index.html
@@ -13,6 +13,10 @@
|
||||
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
@@ -32,16 +36,14 @@
|
||||
<script src="_static/jquery.js"></script>
|
||||
<script src="_static/underscore.js"></script>
|
||||
<script src="_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="genindex.html" />
|
||||
<link rel="search" title="Search" href="search.html" />
|
||||
<link rel="next" title="Packaged Binaries" href="installation/packaged-binaries.html" />
|
||||
<link rel="next" title="Installation" href="getting-started/installation.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
@@ -86,15 +88,10 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/02-fused-softmax.html">Fused Softmax</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -162,45 +159,17 @@
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="welcome-to-triton-s-documentation">
|
||||
<div class="section" id="welcome-to-triton-s-documentation">
|
||||
<h1>Welcome to Triton’s documentation!<a class="headerlink" href="#welcome-to-triton-s-documentation" title="Permalink to this headline">¶</a></h1>
|
||||
<div class="toctree-wrapper compound">
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/from-source.html">From Source</a></li>
|
||||
<p>Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.</p>
|
||||
<div class="section" id="getting-started">
|
||||
<h2>Getting Started<a class="headerlink" href="#getting-started" title="Permalink to this headline">¶</a></h2>
|
||||
<ul class="simple">
|
||||
<li><p>Follow the <a class="reference internal" href="getting-started/installation.html"><span class="doc">installation instructions</span></a> for your platform of choice.</p></li>
|
||||
<li><p>Take a look at the <a class="reference internal" href="getting-started/tutorials/index.html"><span class="doc">tutorials</span></a> to learn how to write your first Triton program.</p></li>
|
||||
</ul>
|
||||
<div class="toctree-wrapper compound">
|
||||
</div>
|
||||
<div class="toctree-wrapper compound">
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -210,7 +179,7 @@
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="installation/packaged-binaries.html" class="btn btn-neutral float-right" title="Packaged Binaries" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="getting-started/installation.html" class="btn btn-neutral float-right" title="Installation" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
@@ -1,246 +0,0 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Packaged Binaries — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="From Source" href="from-source.html" />
|
||||
<link rel="prev" title="Welcome to Triton’s documentation!" href="../index.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../tutorials/01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../tutorials/02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Packaged Binaries</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../_sources/installation/packaged-binaries.rst.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="packaged-binaries">
|
||||
<h1>Packaged Binaries<a class="headerlink" href="#packaged-binaries" title="Permalink to this headline">¶</a></h1>
|
||||
<p>Triton can be installed directly from pip with the following command</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pip</span> <span class="n">install</span> <span class="n">triton</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="from-source.html" class="btn btn-neutral float-right" title="From Source" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="../index.html" class="btn btn-neutral float-left" title="Welcome to Triton’s documentation!" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
BIN
objects.inv
17
search.html
@@ -13,6 +13,10 @@
|
||||
|
||||
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-binder.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-dataframe.css" type="text/css" />
|
||||
<link rel="stylesheet" href="_static/gallery-rendered-html.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
@@ -33,9 +37,7 @@
|
||||
<script src="_static/jquery.js"></script>
|
||||
<script src="_static/underscore.js"></script>
|
||||
<script src="_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="_static/js/theme.js"></script>
|
||||
|
||||
@@ -88,15 +90,10 @@
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<p class="caption"><span class="caption-text">Getting Started</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/02-fused-softmax.html">Fused Softmax</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="getting-started/tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
@@ -1,697 +0,0 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Vector Addition — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
|
||||
<link rel="prev" title="From Source" href="../installation/from-source.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Vector Addition</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../_sources/tutorials/01-vector-add.ipynb.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS for nbsphinx extension */
|
||||
|
||||
/* remove conflicting styling from Sphinx themes */
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt *,
|
||||
div.nbinput.container div.input_area pre,
|
||||
div.nboutput.container div.output_area pre,
|
||||
div.nbinput.container div.input_area .highlight,
|
||||
div.nboutput.container div.output_area .highlight {
|
||||
border: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
div.nbinput.container > div[class*=highlight],
|
||||
div.nboutput.container > div[class*=highlight] {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt * {
|
||||
background: none;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area .highlight,
|
||||
div.nboutput.container div.output_area pre {
|
||||
background: unset;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area div.highlight {
|
||||
color: unset; /* override Pygments text color */
|
||||
}
|
||||
|
||||
/* avoid gaps between output lines */
|
||||
div.nboutput.container div[class*=highlight] pre {
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
/* input/output containers */
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
display: -webkit-flex;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
margin: 0;
|
||||
width: 100%;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
|
||||
/* input container */
|
||||
div.nbinput.container {
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
/* last container */
|
||||
div.nblast.container {
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
|
||||
/* input prompt */
|
||||
div.nbinput.container div.prompt pre {
|
||||
color: #307FC1;
|
||||
}
|
||||
|
||||
/* output prompt */
|
||||
div.nboutput.container div.prompt pre {
|
||||
color: #BF5B3D;
|
||||
}
|
||||
|
||||
/* all prompts */
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: 4.5ex;
|
||||
padding-top: 5px;
|
||||
position: relative;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
margin-right: 0.3ex;
|
||||
}
|
||||
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: unset;
|
||||
text-align: left;
|
||||
padding: 0.4em;
|
||||
}
|
||||
div.nboutput.container div.prompt.empty {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: unset;
|
||||
}
|
||||
}
|
||||
|
||||
/* disable scrollbars on prompts */
|
||||
div.nbinput.container div.prompt pre,
|
||||
div.nboutput.container div.prompt pre {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* input/output area */
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
-webkit-flex: 1;
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
/* input area */
|
||||
div.nbinput.container div.input_area {
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 2px;
|
||||
/*background: #f5f5f5;*/
|
||||
}
|
||||
|
||||
/* override MathJax center alignment in output cells */
|
||||
div.nboutput.container div[class*=MathJax] {
|
||||
text-align: left !important;
|
||||
}
|
||||
|
||||
/* override sphinx.ext.imgmath center alignment in output cells */
|
||||
div.nboutput.container div.math p {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
/* standard error */
|
||||
div.nboutput.container div.output_area.stderr {
|
||||
background: #fdd;
|
||||
}
|
||||
|
||||
/* ANSI colors */
|
||||
.ansi-black-fg { color: #3E424D; }
|
||||
.ansi-black-bg { background-color: #3E424D; }
|
||||
.ansi-black-intense-fg { color: #282C36; }
|
||||
.ansi-black-intense-bg { background-color: #282C36; }
|
||||
.ansi-red-fg { color: #E75C58; }
|
||||
.ansi-red-bg { background-color: #E75C58; }
|
||||
.ansi-red-intense-fg { color: #B22B31; }
|
||||
.ansi-red-intense-bg { background-color: #B22B31; }
|
||||
.ansi-green-fg { color: #00A250; }
|
||||
.ansi-green-bg { background-color: #00A250; }
|
||||
.ansi-green-intense-fg { color: #007427; }
|
||||
.ansi-green-intense-bg { background-color: #007427; }
|
||||
.ansi-yellow-fg { color: #DDB62B; }
|
||||
.ansi-yellow-bg { background-color: #DDB62B; }
|
||||
.ansi-yellow-intense-fg { color: #B27D12; }
|
||||
.ansi-yellow-intense-bg { background-color: #B27D12; }
|
||||
.ansi-blue-fg { color: #208FFB; }
|
||||
.ansi-blue-bg { background-color: #208FFB; }
|
||||
.ansi-blue-intense-fg { color: #0065CA; }
|
||||
.ansi-blue-intense-bg { background-color: #0065CA; }
|
||||
.ansi-magenta-fg { color: #D160C4; }
|
||||
.ansi-magenta-bg { background-color: #D160C4; }
|
||||
.ansi-magenta-intense-fg { color: #A03196; }
|
||||
.ansi-magenta-intense-bg { background-color: #A03196; }
|
||||
.ansi-cyan-fg { color: #60C6C8; }
|
||||
.ansi-cyan-bg { background-color: #60C6C8; }
|
||||
.ansi-cyan-intense-fg { color: #258F8F; }
|
||||
.ansi-cyan-intense-bg { background-color: #258F8F; }
|
||||
.ansi-white-fg { color: #C5C1B4; }
|
||||
.ansi-white-bg { background-color: #C5C1B4; }
|
||||
.ansi-white-intense-fg { color: #A1A6B2; }
|
||||
.ansi-white-intense-bg { background-color: #A1A6B2; }
|
||||
|
||||
.ansi-default-inverse-fg { color: #FFFFFF; }
|
||||
.ansi-default-inverse-bg { background-color: #000000; }
|
||||
|
||||
.ansi-bold { font-weight: bold; }
|
||||
.ansi-underline { text-decoration: underline; }
|
||||
|
||||
|
||||
div.nbinput.container div.input_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight].math,
|
||||
div.nboutput.container div.output_area.rendered_html,
|
||||
div.nboutput.container div.output_area > div.output_javascript,
|
||||
div.nboutput.container div.output_area:not(.rendered_html) > img{
|
||||
padding: 5px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
|
||||
div.nbinput.container div.input_area > div[class^='highlight'],
|
||||
div.nboutput.container div.output_area > div[class^='highlight']{
|
||||
overflow-y: hidden;
|
||||
}
|
||||
|
||||
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
|
||||
.prompt a.copybtn {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Some additional styling taken form the Jupyter notebook CSS */
|
||||
div.rendered_html table {
|
||||
border: none;
|
||||
border-collapse: collapse;
|
||||
border-spacing: 0;
|
||||
color: black;
|
||||
font-size: 12px;
|
||||
table-layout: fixed;
|
||||
}
|
||||
div.rendered_html thead {
|
||||
border-bottom: 1px solid black;
|
||||
vertical-align: bottom;
|
||||
}
|
||||
div.rendered_html tr,
|
||||
div.rendered_html th,
|
||||
div.rendered_html td {
|
||||
text-align: right;
|
||||
vertical-align: middle;
|
||||
padding: 0.5em 0.5em;
|
||||
line-height: normal;
|
||||
white-space: normal;
|
||||
max-width: none;
|
||||
border: none;
|
||||
}
|
||||
div.rendered_html th {
|
||||
font-weight: bold;
|
||||
}
|
||||
div.rendered_html tbody tr:nth-child(odd) {
|
||||
background: #f5f5f5;
|
||||
}
|
||||
div.rendered_html tbody tr:hover {
|
||||
background: rgba(66, 165, 245, 0.2);
|
||||
}
|
||||
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="Vector-Addition">
|
||||
<h1>Vector Addition<a class="headerlink" href="#Vector-Addition" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn: * The basic syntax of the Triton programming language * The best practices for creating PyTorch custom operators using the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API * The best practices for validating and benchmarking custom ops against native reference implementations</p>
|
||||
<div class="section" id="Writing-the-Compute-Kernel">
|
||||
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Each compute kernel is declared using the <code class="docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel on different chunks of data (See the <a class="reference external" href="https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a> programming model for more details).</p>
|
||||
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
|
||||
<span class="c1">// of the program in the overaching SPMD context</span>
|
||||
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
|
||||
<span class="c1">// different chunks of data in parallel.</span>
|
||||
<span class="c1">// For those similar with CUDA, `get_program_id({0,1,2})`</span>
|
||||
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
|
||||
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// In Triton, arrays are first-class citizen. In other words,</span>
|
||||
<span class="c1">// they are primitives data-types and are -- contrary to C and</span>
|
||||
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
|
||||
<span class="c1">// memory.</span>
|
||||
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
|
||||
<span class="c1">// whose memory values are, e.g.:</span>
|
||||
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
|
||||
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
|
||||
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
|
||||
<span class="c1">// be achieved using the the ternary operator `cond ? val_true : val_false`</span>
|
||||
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr</span>
|
||||
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
|
||||
<span class="c1">// write-back `z`</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL’2019 Triton paper</a>.</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-the-Torch-bindings">
|
||||
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The only thing that matters when it comes to Triton and Torch is the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects.</p>
|
||||
<p>To create a <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things: * <code class="docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create * <code class="docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for * <code class="docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <code class="docutils literal notranslate"><span class="pre">#define</span></code> for you</p>
|
||||
<p>Note: The constructor of <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see <code class="docutils literal notranslate"><span class="pre">_kernels</span></code> variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator’s inputs.</p>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[10]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># source-code for Triton compute kernel</span>
|
||||
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
|
||||
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
|
||||
<span class="s2"> // program id</span>
|
||||
<span class="s2"> int pid = get_program_id(0);</span>
|
||||
<span class="s2"> // create arrays of pointers</span>
|
||||
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
|
||||
<span class="s2"> float* px[BLOCK] = x + offset;</span>
|
||||
<span class="s2"> float* py[BLOCK] = y + offset;</span>
|
||||
<span class="s2"> // bounds checking</span>
|
||||
<span class="s2"> bool check[BLOCK] = offset < N;</span>
|
||||
<span class="s2"> // write-back</span>
|
||||
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2"> """</span>
|
||||
<span class="c1"># This function returns a callable `triton.kernel` object</span>
|
||||
<span class="c1"># created from the above source code.</span>
|
||||
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
|
||||
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
|
||||
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
|
||||
<span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># This is a standard torch custom autograd Function</span>
|
||||
<span class="c1"># The only difference is that we can now use the above kernel</span>
|
||||
<span class="c1"># in the `forward` and `backward` functions.`</span>
|
||||
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="c1"># *allocate output*</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
|
||||
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
|
||||
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
|
||||
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
|
||||
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
|
||||
<span class="c1"># the `.data_ptr()` method</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">z</span>
|
||||
<span class="c1"># Just like we standard PyTorch ops</span>
|
||||
<span class="c1"># We use the `.apply` method to create a</span>
|
||||
<span class="c1"># callable object for our function</span>
|
||||
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<p>At this point <code class="docutils literal notranslate"><span class="pre">add(x,</span> <span class="pre">y)</span></code> is equivalent to <code class="docutils literal notranslate"><span class="pre">x</span> <span class="pre">+</span> <span class="pre">y</span></code> for contiguous tensors. Now let’s test and benchmark it!</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Unit-Test">
|
||||
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[9]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span>
|
||||
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<div class="highlight"><pre>
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
The maximum difference between torch and triton is 0.0
|
||||
</pre></div></div>
|
||||
</div>
|
||||
<p>Seems to work!</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Benchmark">
|
||||
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The performance of our GPU code can be benchmark using the <code class="docutils literal notranslate"><span class="pre">torch.cuda.Event(enable_timing=True)</span></code> wrapper. Below is a simple function that benchmarks <code class="docutils literal notranslate"><span class="pre">rep</span></code> runs of our kernels after <code class="docutils literal notranslate"><span class="pre">warmup</span></code> “cold” runs.</p>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[11]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="c1"># We now want to benchmark the performance of `add`</span>
|
||||
<span class="c1"># Against that of PyTorch for increasing vector sizes</span>
|
||||
<span class="k">def</span> <span class="nf">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">warmup</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">rep</span> <span class="o">=</span> <span class="mi">50</span><span class="p">):</span>
|
||||
<span class="n">start_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">end_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">ret</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">warmup</span><span class="p">):</span>
|
||||
<span class="n">fn</span><span class="p">()</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
||||
<span class="n">start_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">rep</span><span class="p">):</span>
|
||||
<span class="n">fn</span><span class="p">()</span>
|
||||
<span class="n">end_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
||||
<span class="n">time_ms</span> <span class="o">=</span> <span class="n">start_event</span><span class="o">.</span><span class="n">elapsed_time</span><span class="p">(</span><span class="n">end_event</span><span class="p">)</span> <span class="o">/</span> <span class="n">rep</span>
|
||||
<span class="k">return</span> <span class="n">time_ms</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does</p>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[15]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="mi">26</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">triton_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
|
||||
<span class="n">torch_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="c1"># print the performance of triton and torch as well as the achieved bandwidth</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">N</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">triton_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">torch_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<div class="highlight"><pre>
|
||||
131072 0.020 0.003
|
||||
262144 0.019 0.004
|
||||
524288 0.016 0.016
|
||||
1048576 0.033 0.033
|
||||
2097152 0.071 0.070
|
||||
4194304 0.142 0.144
|
||||
8388608 0.287 0.286
|
||||
16777216 0.572 0.568
|
||||
33554432 1.139 1.110
|
||||
</pre></div></div>
|
||||
</div>
|
||||
<p>Our op is on-par with Torch’s vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce.</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="../installation/from-source.html" class="btn btn-neutral float-left" title="From Source" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
@@ -1,329 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "acute-possession",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Vector Addition"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "median-malaysia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
|
||||
"* The basic syntax of the Triton programming language\n",
|
||||
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
|
||||
"* The best practices for validating and benchmarking custom ops against native reference implementations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "identical-conditions",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Compute Kernel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "collectible-belle",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```c\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // The `get_program_id(i)` returns the i-th coordinate\n",
|
||||
" // of the program in the overaching SPMD context\n",
|
||||
" // (a.k.a launch grid). This is what allows us to process\n",
|
||||
" // different chunks of data in parallel.\n",
|
||||
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
|
||||
" // is similar to blockIdx.{x,y,z}\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // In Triton, arrays are first-class citizen. In other words,\n",
|
||||
" // they are primitives data-types and are -- contrary to C and\n",
|
||||
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
|
||||
" // memory.\n",
|
||||
" // In the few lines below, we create an array of `BLOCK` pointers\n",
|
||||
" // whose memory values are, e.g.:\n",
|
||||
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
|
||||
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz [BLOCK] = z + offset;\n",
|
||||
" float* px [BLOCK] = x + offset;\n",
|
||||
" float* py [BLOCK] = y + offset;\n",
|
||||
" // Simple element-wise control-flow for load/store operations can\n",
|
||||
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
|
||||
" // or the conditional dereferencing operator `*?(cond)ptr\n",
|
||||
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
|
||||
" // write-back `z`\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "forbidden-wednesday",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Torch bindings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "numerical-agency",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
|
||||
"\n",
|
||||
"To create a `triton.kernel`, you only need three things:\n",
|
||||
"* `source: string`: the source-code of the kernel you want to create\n",
|
||||
"* `device: torch.device`: the device you want to compile this code for\n",
|
||||
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
|
||||
"\n",
|
||||
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "sporting-keyboard",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import triton\n",
|
||||
"\n",
|
||||
"# source-code for Triton compute kernel\n",
|
||||
"# here we just copy-paste the above code without the extensive comments.\n",
|
||||
"# you may prefer to store it in a .c file and load it from there instead.\n",
|
||||
"_src = \"\"\"\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // program id\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // create arrays of pointers\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz[BLOCK] = z + offset;\n",
|
||||
" float* px[BLOCK] = x + offset;\n",
|
||||
" float* py[BLOCK] = y + offset;\n",
|
||||
" // bounds checking\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" // write-back\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
" \"\"\"\n",
|
||||
"# This function returns a callable `triton.kernel` object\n",
|
||||
"# created from the above source code.\n",
|
||||
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
|
||||
"# We compile the kernel with -DBLOCK=1024\n",
|
||||
"_kernels = dict()\n",
|
||||
"def make_add_kernel(device):\n",
|
||||
" if device not in _kernels:\n",
|
||||
" defines = {'BLOCK': 1024}\n",
|
||||
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
|
||||
" return _kernels[device]\n",
|
||||
"\n",
|
||||
"# This is a standard torch custom autograd Function\n",
|
||||
"# The only difference is that we can now use the above kernel\n",
|
||||
"# in the `forward` and `backward` functions.`\n",
|
||||
"class _add(torch.autograd.Function):\n",
|
||||
" \n",
|
||||
" @staticmethod\n",
|
||||
" def forward(ctx, x, y):\n",
|
||||
" # constraints of the op\n",
|
||||
" assert x.dtype == torch.float32\n",
|
||||
" # *allocate output*\n",
|
||||
" z = torch.empty_like(x)\n",
|
||||
" # *create launch grid*:\n",
|
||||
" # this is a function which takes compilation parameters `opt`\n",
|
||||
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
|
||||
" # triton.cdiv is a shortcut for ceil division:\n",
|
||||
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
|
||||
" N = z.shape[0]\n",
|
||||
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
|
||||
" # *launch kernel*:\n",
|
||||
" # pointer to the data of torch tensors can be retrieved with\n",
|
||||
" # the `.data_ptr()` method\n",
|
||||
" kernel = make_add_kernel(z.device)\n",
|
||||
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
|
||||
" return z\n",
|
||||
"# Just like we standard PyTorch ops\n",
|
||||
"# We use the `.apply` method to create a \n",
|
||||
"# callable object for our function\n",
|
||||
"add = _add.apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "separated-polyester",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exclusive-salvation",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Unit Test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "supported-ribbon",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"The maximum difference between torch and triton is 0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\n",
|
||||
"x = torch.rand(98432, device='cuda')\n",
|
||||
"y = torch.rand(98432, device='cuda')\n",
|
||||
"za = x + y\n",
|
||||
"zb = add(x, y)\n",
|
||||
"print(za)\n",
|
||||
"print(zb)\n",
|
||||
"print(f'The maximum difference between torch and triton is '\n",
|
||||
" f'{torch.max(torch.abs(za - zb))}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "otherwise-canadian",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Seems to work!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "polished-australia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Benchmark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "historic-glass",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "strange-luxembourg",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We now want to benchmark the performance of `add`\n",
|
||||
"# Against that of PyTorch for increasing vector sizes\n",
|
||||
"def do_bench(fn, warmup = 10, rep = 50):\n",
|
||||
" start_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" end_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" ret = fn()\n",
|
||||
" for i in range(warmup):\n",
|
||||
" fn()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" start_event.record()\n",
|
||||
" for i in range(rep):\n",
|
||||
" fn()\n",
|
||||
" end_event.record()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" time_ms = start_event.elapsed_time(end_event) / rep\n",
|
||||
" return time_ms"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "hairy-claim",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "pleasant-valley",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"131072 0.020 0.003\n",
|
||||
"262144 0.019 0.004\n",
|
||||
"524288 0.016 0.016\n",
|
||||
"1048576 0.033 0.033\n",
|
||||
"2097152 0.071 0.070\n",
|
||||
"4194304 0.142 0.144\n",
|
||||
"8388608 0.287 0.286\n",
|
||||
"16777216 0.572 0.568\n",
|
||||
"33554432 1.139 1.110\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for N in [2**i for i in range(17, 26, 1)]:\n",
|
||||
" x = torch.rand(N, device='cuda')\n",
|
||||
" y = torch.rand(N, device='cuda')\n",
|
||||
" triton_ms = do_bench(lambda: add(x, y))\n",
|
||||
" torch_ms = do_bench(lambda: x + y)\n",
|
||||
" # print the performance of triton and torch as well as the achieved bandwidth\n",
|
||||
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "juvenile-supplement",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@@ -1,705 +0,0 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Fused Softmax — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="prev" title="Vector Addition" href="01-vector-add.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Fused Softmax</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../_sources/tutorials/02-fused-softmax.ipynb.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS for nbsphinx extension */
|
||||
|
||||
/* remove conflicting styling from Sphinx themes */
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt *,
|
||||
div.nbinput.container div.input_area pre,
|
||||
div.nboutput.container div.output_area pre,
|
||||
div.nbinput.container div.input_area .highlight,
|
||||
div.nboutput.container div.output_area .highlight {
|
||||
border: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
div.nbinput.container > div[class*=highlight],
|
||||
div.nboutput.container > div[class*=highlight] {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt * {
|
||||
background: none;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area .highlight,
|
||||
div.nboutput.container div.output_area pre {
|
||||
background: unset;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area div.highlight {
|
||||
color: unset; /* override Pygments text color */
|
||||
}
|
||||
|
||||
/* avoid gaps between output lines */
|
||||
div.nboutput.container div[class*=highlight] pre {
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
/* input/output containers */
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
display: -webkit-flex;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
margin: 0;
|
||||
width: 100%;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
|
||||
/* input container */
|
||||
div.nbinput.container {
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
/* last container */
|
||||
div.nblast.container {
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
|
||||
/* input prompt */
|
||||
div.nbinput.container div.prompt pre {
|
||||
color: #307FC1;
|
||||
}
|
||||
|
||||
/* output prompt */
|
||||
div.nboutput.container div.prompt pre {
|
||||
color: #BF5B3D;
|
||||
}
|
||||
|
||||
/* all prompts */
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: 4.5ex;
|
||||
padding-top: 5px;
|
||||
position: relative;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
margin-right: 0.3ex;
|
||||
}
|
||||
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: unset;
|
||||
text-align: left;
|
||||
padding: 0.4em;
|
||||
}
|
||||
div.nboutput.container div.prompt.empty {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: unset;
|
||||
}
|
||||
}
|
||||
|
||||
/* disable scrollbars on prompts */
|
||||
div.nbinput.container div.prompt pre,
|
||||
div.nboutput.container div.prompt pre {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* input/output area */
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
-webkit-flex: 1;
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
/* input area */
|
||||
div.nbinput.container div.input_area {
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 2px;
|
||||
/*background: #f5f5f5;*/
|
||||
}
|
||||
|
||||
/* override MathJax center alignment in output cells */
|
||||
div.nboutput.container div[class*=MathJax] {
|
||||
text-align: left !important;
|
||||
}
|
||||
|
||||
/* override sphinx.ext.imgmath center alignment in output cells */
|
||||
div.nboutput.container div.math p {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
/* standard error */
|
||||
div.nboutput.container div.output_area.stderr {
|
||||
background: #fdd;
|
||||
}
|
||||
|
||||
/* ANSI colors */
|
||||
.ansi-black-fg { color: #3E424D; }
|
||||
.ansi-black-bg { background-color: #3E424D; }
|
||||
.ansi-black-intense-fg { color: #282C36; }
|
||||
.ansi-black-intense-bg { background-color: #282C36; }
|
||||
.ansi-red-fg { color: #E75C58; }
|
||||
.ansi-red-bg { background-color: #E75C58; }
|
||||
.ansi-red-intense-fg { color: #B22B31; }
|
||||
.ansi-red-intense-bg { background-color: #B22B31; }
|
||||
.ansi-green-fg { color: #00A250; }
|
||||
.ansi-green-bg { background-color: #00A250; }
|
||||
.ansi-green-intense-fg { color: #007427; }
|
||||
.ansi-green-intense-bg { background-color: #007427; }
|
||||
.ansi-yellow-fg { color: #DDB62B; }
|
||||
.ansi-yellow-bg { background-color: #DDB62B; }
|
||||
.ansi-yellow-intense-fg { color: #B27D12; }
|
||||
.ansi-yellow-intense-bg { background-color: #B27D12; }
|
||||
.ansi-blue-fg { color: #208FFB; }
|
||||
.ansi-blue-bg { background-color: #208FFB; }
|
||||
.ansi-blue-intense-fg { color: #0065CA; }
|
||||
.ansi-blue-intense-bg { background-color: #0065CA; }
|
||||
.ansi-magenta-fg { color: #D160C4; }
|
||||
.ansi-magenta-bg { background-color: #D160C4; }
|
||||
.ansi-magenta-intense-fg { color: #A03196; }
|
||||
.ansi-magenta-intense-bg { background-color: #A03196; }
|
||||
.ansi-cyan-fg { color: #60C6C8; }
|
||||
.ansi-cyan-bg { background-color: #60C6C8; }
|
||||
.ansi-cyan-intense-fg { color: #258F8F; }
|
||||
.ansi-cyan-intense-bg { background-color: #258F8F; }
|
||||
.ansi-white-fg { color: #C5C1B4; }
|
||||
.ansi-white-bg { background-color: #C5C1B4; }
|
||||
.ansi-white-intense-fg { color: #A1A6B2; }
|
||||
.ansi-white-intense-bg { background-color: #A1A6B2; }
|
||||
|
||||
.ansi-default-inverse-fg { color: #FFFFFF; }
|
||||
.ansi-default-inverse-bg { background-color: #000000; }
|
||||
|
||||
.ansi-bold { font-weight: bold; }
|
||||
.ansi-underline { text-decoration: underline; }
|
||||
|
||||
|
||||
div.nbinput.container div.input_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight].math,
|
||||
div.nboutput.container div.output_area.rendered_html,
|
||||
div.nboutput.container div.output_area > div.output_javascript,
|
||||
div.nboutput.container div.output_area:not(.rendered_html) > img{
|
||||
padding: 5px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
|
||||
div.nbinput.container div.input_area > div[class^='highlight'],
|
||||
div.nboutput.container div.output_area > div[class^='highlight']{
|
||||
overflow-y: hidden;
|
||||
}
|
||||
|
||||
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
|
||||
.prompt a.copybtn {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Some additional styling taken form the Jupyter notebook CSS */
|
||||
div.rendered_html table {
|
||||
border: none;
|
||||
border-collapse: collapse;
|
||||
border-spacing: 0;
|
||||
color: black;
|
||||
font-size: 12px;
|
||||
table-layout: fixed;
|
||||
}
|
||||
div.rendered_html thead {
|
||||
border-bottom: 1px solid black;
|
||||
vertical-align: bottom;
|
||||
}
|
||||
div.rendered_html tr,
|
||||
div.rendered_html th,
|
||||
div.rendered_html td {
|
||||
text-align: right;
|
||||
vertical-align: middle;
|
||||
padding: 0.5em 0.5em;
|
||||
line-height: normal;
|
||||
white-space: normal;
|
||||
max-width: none;
|
||||
border: none;
|
||||
}
|
||||
div.rendered_html th {
|
||||
font-weight: bold;
|
||||
}
|
||||
div.rendered_html tbody tr:nth-child(odd) {
|
||||
background: #f5f5f5;
|
||||
}
|
||||
div.rendered_html tbody tr:hover {
|
||||
background: rgba(66, 165, 245, 0.2);
|
||||
}
|
||||
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="Fused-Softmax">
|
||||
<h1>Fused Softmax<a class="headerlink" href="#Fused-Softmax" title="Permalink to this headline">¶</a></h1>
|
||||
<p>Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
<span class="c1"># Compute the row-wise softmax of x \in R^{M \times N}</span>
|
||||
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">x_max</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># read MN elements ; write MN elements</span>
|
||||
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">denominator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span>
|
||||
<span class="k">return</span> <span class="n">ret</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<p>When implemented naively in pytorch, computing <span class="math notranslate nohighlight">\(y\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.</p>
|
||||
<p>Instead, we want to write a custom “fused” pytorch operators that only reads X once and does all the necessary computations on-chip. This would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.</p>
|
||||
<div class="section" id="Writing-the-Compute-Kernel">
|
||||
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">softmax</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">Y</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">X</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// row index</span>
|
||||
<span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// column indices</span>
|
||||
<span class="kt">int</span> <span class="n">n</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="c1">// the memory address of all the elements</span>
|
||||
<span class="c1">// that we want to load can be computed as follows</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="c1">// because BLOCK has to be a power of two</span>
|
||||
<span class="c1">// (per Triton-C specs), it is important</span>
|
||||
<span class="c1">// to guard each memory operation with predicates</span>
|
||||
<span class="c1">// or we will read out of bounds</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="kt">float</span> <span class="n">x</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">check</span> <span class="o">?</span> <span class="o">*</span><span class="nl">px</span> <span class="p">:</span> <span class="o">-</span><span class="n">F32_INFINITY</span><span class="p">;</span>
|
||||
<span class="c1">// syntax for reduction in Triton is:</span>
|
||||
<span class="c1">// x[..., OPERATOR, ...]</span>
|
||||
<span class="c1">// ^</span>
|
||||
<span class="c1">// index</span>
|
||||
<span class="c1">// The operators currently supported are {min, max, +}</span>
|
||||
<span class="kt">float</span> <span class="n">z</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="p">[</span><span class="n">max</span><span class="p">];</span>
|
||||
<span class="c1">// The exponential in Triton is fast but approximate</span>
|
||||
<span class="c1">// (i.e., like __expf in CUDA)</span>
|
||||
<span class="kt">float</span> <span class="n">num</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
|
||||
<span class="kt">float</span> <span class="n">denom</span> <span class="o">=</span> <span class="n">num</span><span class="p">[</span><span class="o">+</span><span class="p">];</span>
|
||||
<span class="c1">// The result of the reduction is now stored in y</span>
|
||||
<span class="kt">float</span> <span class="n">y</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">;</span>
|
||||
<span class="c1">// We write it back</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span> <span class="o">=</span> <span class="n">y</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="Writing-the-Torch-bindings">
|
||||
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># source-code for Triton compute kernel</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){</span>
|
||||
<span class="s2"> int m = get_program_id(0);</span>
|
||||
<span class="s2"> int n [BLOCK] = 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* px [BLOCK] = X + m*stride_xm + n;</span>
|
||||
<span class="s2"> bool check[BLOCK] = n < N;</span>
|
||||
<span class="s2"> float x [BLOCK] = check ? *px : -F32_INFINITY;</span>
|
||||
<span class="s2"> float z [BLOCK] = x - x[max];</span>
|
||||
<span class="s2"> float num [BLOCK] = exp(z);</span>
|
||||
<span class="s2"> float denom = num[+];</span>
|
||||
<span class="s2"> float y [BLOCK] = num / denom;</span>
|
||||
<span class="s2"> float* py [BLOCK] = Y + m*stride_ym + n;</span>
|
||||
<span class="s2"> *?(check)py = y;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="c1"># We need to make sure that BLOCK is the smallest power of two</span>
|
||||
<span class="c1"># greater than the number of rows N of the input matrix.</span>
|
||||
<span class="c1"># Different values of BLOCK will result in different kernels</span>
|
||||
<span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">2</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">4</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">8</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">16</span>
|
||||
<span class="n">n</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="k">return</span> <span class="n">n</span>
|
||||
|
||||
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
|
||||
<span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
|
||||
|
||||
<span class="k">class</span> <span class="nc">_softmax</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># here we just launch a grid of M programs</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
|
||||
<span class="n">softmax</span> <span class="o">=</span> <span class="n">_softmax</span><span class="o">.</span><span class="n">apply</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Unit-Test">
|
||||
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">y_tri</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">y_ref</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<div class="highlight"><pre>
|
||||
tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],
|
||||
[0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],
|
||||
[0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],
|
||||
...,
|
||||
[0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],
|
||||
[0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],
|
||||
[0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],
|
||||
device='cuda:0')
|
||||
tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],
|
||||
[0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],
|
||||
[0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],
|
||||
...,
|
||||
[0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],
|
||||
[0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],
|
||||
[0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],
|
||||
device='cuda:0')
|
||||
True
|
||||
</pre></div></div>
|
||||
</div>
|
||||
<p>Seems to work!</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Benchmark">
|
||||
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
|
||||
|
||||
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
|
||||
<span class="n">Ns</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span><span class="o">*</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">50</span><span class="p">)]</span>
|
||||
<span class="n">tri_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">ref_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">def_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="n">Ns</span><span class="p">:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">nelement</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="mf">1e-9</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||||
<span class="n">tri_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
|
||||
<span class="n">ref_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)))]</span>
|
||||
<span class="n">def_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">'N'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">'Bandwidth (GB/s)'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">tri_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">'Triton'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">ref_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">'Torch'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">def_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">'Naive'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<img alt="../_images/tutorials_02-fused-softmax_12_0.png" src="../_images/tutorials_02-fused-softmax_12_0.png" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|