[DOCS] Added non-tutorial documentation pages
This commit is contained in:
@@ -7,3 +7,12 @@ This is the development repository of Triton, a language and compiler for writin
|
||||
The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!
|
||||
|
||||
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.
|
||||
|
||||
# Compatibility
|
||||
|
||||
Supported Platforms:
|
||||
* Linux
|
||||
|
||||
Supported Hardware:
|
||||
* NVIDIA GPUs (Compute Capability 7.0+)
|
||||
* Under development: AMD GPUs, CPUs
|
||||
|
32
docs/conf.py
32
docs/conf.py
@@ -30,19 +30,21 @@
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = ['sphinx.ext.autosectionlabel']
|
||||
autosectionlabel_prefix_document = True
|
||||
extensions = []
|
||||
|
||||
# Math Jax
|
||||
extensions += ['sphinx.ext.mathjax']
|
||||
|
||||
# Sphinx gallery
|
||||
extensions += ['sphinx_gallery.gen_gallery']
|
||||
from sphinx_gallery.sorting import FileNameSortKey
|
||||
sphinx_gallery_conf = {
|
||||
'examples_dirs': '../python/tutorials/',
|
||||
'gallery_dirs': 'getting-started/tutorials',
|
||||
'filename_pattern': '',
|
||||
'ignore_pattern': r'__init__\.py',
|
||||
'within_subsection_order': FileNameSortKey,
|
||||
}
|
||||
# extensions += ['sphinx_gallery.gen_gallery']
|
||||
# from sphinx_gallery.sorting import FileNameSortKey
|
||||
# sphinx_gallery_conf = {
|
||||
# 'examples_dirs': '../python/tutorials/',
|
||||
# 'gallery_dirs': 'getting-started/tutorials',
|
||||
# 'filename_pattern': '',
|
||||
# 'ignore_pattern': r'__init__\.py',
|
||||
# 'within_subsection_order': FileNameSortKey,
|
||||
# }
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
@@ -107,6 +109,9 @@ html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_css_files = [
|
||||
'css/custom.css',
|
||||
]
|
||||
|
||||
# Custom sidebar templates, must be a dictionary that maps document names
|
||||
# to template names.
|
||||
@@ -164,8 +169,5 @@ man_pages = [(master_doc, 'triton', 'Triton Documentation', [author], 1)]
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(
|
||||
master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.',
|
||||
'Miscellaneous'
|
||||
),
|
||||
(master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.', 'Miscellaneous'),
|
||||
]
|
@@ -16,3 +16,16 @@ Getting Started
|
||||
|
||||
getting-started/installation
|
||||
getting-started/tutorials/index
|
||||
|
||||
Going Further
|
||||
--------------
|
||||
|
||||
- Check out the :doc:`programming guide <programming-guide/index>` to learn more about Triton and how it compares against other DSLs for DNNs.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Programming Guide
|
||||
:hidden:
|
||||
|
||||
programming-guide/introduction
|
||||
programming-guide/related-work
|
BIN
docs/programming-guide/cuda-parallel-matmul.png
Normal file
BIN
docs/programming-guide/cuda-parallel-matmul.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 9.5 KiB |
BIN
docs/programming-guide/halide-iteration.png
Normal file
BIN
docs/programming-guide/halide-iteration.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
69
docs/programming-guide/introduction.rst
Normal file
69
docs/programming-guide/introduction.rst
Normal file
@@ -0,0 +1,69 @@
|
||||
==============
|
||||
Introduction
|
||||
==============
|
||||
|
||||
--------------
|
||||
Motivations
|
||||
--------------
|
||||
|
||||
Over the past decade, Deep Neural Networks (DNNs) have emerged as an important class of Machine Learning (ML) models, capable of achieving state-of-the-art performance across many domains ranging from natural language processing [1]_ to computer vision [2]_ to computational neuroscience [3]_. The strength of these models lies in their hierarchical structure, composed of a sequence of parametric (e.g., convolutional) and non-parametric (e.g., rectified linearity) *layers*. This pattern, though notoriously computationally expensive, also generates a large amount of highly parallelizable work particularly well suited for multi- and many- core processors.
|
||||
|
||||
As a consequence, Graphics Processing Units (GPUs) have become a cheap and accessible resource for exploring and/or deploying novel research ideas in the field. This trend has been accelerated by the release of several frameworks for General-Purpose GPU (GPGPU) computing, such as CUDA and OpenCL, which have made the development of high-performance programs easier. Yet, GPUs remain incredibly challenging to optimize for locality and parallelism, especially for computations that cannot be efficiently implemented using a combination of pre-existing optimized primitives. To make matters worse, GPU architectures are also rapidly evolving and specializing, as evidenced by the addition of tensor cores to NVIDIA (and more recently AMD) micro-architectures.
|
||||
|
||||
This tension between the computational opportunities offered by DNNs and the practical difficulty of GPU programming has created substantial academic and industrial interest for Domain-Specific Languages (DSLs) and compilers. Regrettably, these systems -- whether they be based on polyhedral machinery (*e.g.*, Tiramisu [4]_, Tensor Comprehensions [5]_) or scheduling languages (*e.g.*, Halide [6]_, TVM [7]_) -- remain less flexible and (for the same algorithm) markedly slower than the best handwritten compute kernels available in libraries like `cuBLAS <https://docs.nvidia.com/cuda/cublas/index.html>`_, `cuDNN <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>`_ or `TensorRT <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html>`_.
|
||||
|
||||
The main premise of this project is the following: programming paradigms based on blocked algorithms [8]_ can facilitate the construction of high-performance compute kernels for neural networks. We specifically revisit traditional "Single Program, Multiple Data" (SPMD [9]_) execution models for GPUs, and propose a variant in which programs -- rather than threads -- are blocked. For example, in the case of matrix multiplication, CUDA and Triton differ as follows:
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| CUDA Programming Model | Triton Programming Model |
|
||||
| | |
|
||||
| (Scalar Program, Blocked Threads) | (Blocked Program, Scalar Threads) |
|
||||
+=====================================================+=====================================================+
|
||||
| | |
|
||||
|.. code-block:: C |.. code-block:: C |
|
||||
| | :force: |
|
||||
| | |
|
||||
| #pragma parallel | #pragma parallel |
|
||||
| for(int m = 0; i < M; m++) | for(int m = 0; m < M; m += MB) |
|
||||
| #pragma parallel | #pragma parallel |
|
||||
| for(int n = 0; j < N; n++){ | for(int n = 0; n < N; n += NB){ |
|
||||
| float acc = 0; | float acc[MB, NB] = 0; |
|
||||
| for(int k = 0; k < K;k ++) | for(int k = 0; k < K; k += KB) |
|
||||
| acc += A[i, k]* B[k, j]; | acc += A[m:m+MB, k:k+KB] |
|
||||
| | @ B[k:k+KB, n:n+NB]; |
|
||||
| C[i, j] = acc; | C[m:m+MB, n:n+NB] = acc; |
|
||||
| } | } |
|
||||
| | |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| |pic1| | |pic2| |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
|
||||
|
||||
.. |pic1| image:: cuda-parallel-matmul.png
|
||||
|
||||
.. |pic2| image:: triton-parallel-matmul.png
|
||||
|
||||
A key benefit of this approach is that it leads to block-structured iteration spaces that offer programmers more flexibility than existing DSLs when implementing sparse operations, all while allowing compilers to aggressively optimize programs for data locality and parallelism.
|
||||
|
||||
--------------
|
||||
Challenges
|
||||
--------------
|
||||
|
||||
The main challenge posed by our proposed paradigm is that of work scheduling, i.e., how the work done by each program instance should be partitioned for efficient execution on modern GPUs. To address this issue, the Triton compiler makes heavy use of *block-level data-flow analysis*, a technique for scheduling iteration blocks statically based on the control- and data-flow structure of the target program. The resulting system actually works surprisingly well: our compiler manages to apply a broad range of interesting optimization automatically (e.g., automatic coalescing, thread swizzling, pre-fetching, automatic vectorization, tensor core-aware instruction selection, shared memory allocation/synchronization, asynchronous copy scheduling). Of course doing all this is not trivial; one of the purposes of this guide is to give you a sense of how it works.
|
||||
|
||||
--------------
|
||||
References
|
||||
--------------
|
||||
|
||||
.. [1] Sutskever et al., "Sequence to Sequence Learning with Neural Networks", NIPS 2014
|
||||
.. [2] Redmon et al., "You Only Look Once: Unified, Real-Time Object Detection", CVPR 2016
|
||||
.. [3] Lee et al., "Superhuman Accuracy on the SNEMI3D Connectomics Challenge", ArXiV 2017
|
||||
.. [4] Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
|
||||
.. [5] Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiV 2018
|
||||
.. [6] Ragan-Kelley et al., "Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines", PLDI 2013
|
||||
.. [7] Chen et al., "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning", OSDI 2018
|
||||
.. [8] Lam et al., "The Cache Performance and Optimizations of Blocked Algorithms", ASPLOS 1991
|
||||
.. [9] Auguin et al., "Opsila: an advanced SIMD for numerical analysis and signal processing", EUROMICRO 1983
|
BIN
docs/programming-guide/polyhedral-iteration.png
Normal file
BIN
docs/programming-guide/polyhedral-iteration.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 59 KiB |
209
docs/programming-guide/related-work.rst
Normal file
209
docs/programming-guide/related-work.rst
Normal file
@@ -0,0 +1,209 @@
|
||||
==============
|
||||
Related Work
|
||||
==============
|
||||
|
||||
At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlights its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.
|
||||
|
||||
-----------------------
|
||||
Polyhedral Compilation
|
||||
-----------------------
|
||||
|
||||
Traditional compilers typically rely on intermediate representations, such as LLVM-IR [1]_, that encode control flow information using (un)conditional branches. This relatively low-level format makes it difficult to statically analyze the runtime behavior (e.g., cache misses) of input programs, and to automatically optimize loops accordingly through the use of tiling [2]_, fusion [3]_ and interchange [4]_. To solve this issue, polyhedral compilers [5]_ rely on program representations that have statically predictable control flow, thereby enabling aggressive compile-time program transformations for data locality and parallelism. Though this strategy has been adopted by many languages and compilers for DNNs such as Tiramisu [6]_, Tensor Comprehensions [7]_, Diesel [8]_ and the Affine dialect in MLIR [9]_, it also comes with a number of limitations that will be described later.
|
||||
|
||||
+++++++++++++++++++++++
|
||||
Program Representation
|
||||
+++++++++++++++++++++++
|
||||
|
||||
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| | |
|
||||
|.. code-block:: C | |pic1| |
|
||||
| | |
|
||||
| for(int i = 0; i < 3; i++) | |
|
||||
| for(int j = i; j < 5; j++) | |
|
||||
| A[i][j] = 0; | |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
|
||||
.. |pic1| image:: polyhedral-iteration.png
|
||||
:width: 300
|
||||
|
||||
Polyhedral compilers focus on a class of programs commonly known as **Static Control Parts** (SCoP), *i.e.*, maximal sets of consecutive statements in which conditionals and loop bounds are affine functions of surrounding loop indices and global invariant parameters. As shown above, programs in this format always lead to iteration domains that are bounded by affine inequalities, i.e., polyhedral. These polyhedra can also be defined algebraically; for the above example:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathcal{P} = \{ i, j \in \mathbb{Z}^2
|
||||
~|~
|
||||
\begin{pmatrix}
|
||||
1 & 0 \\
|
||||
-1 & 0 \\
|
||||
-1 & 1 \\
|
||||
0 & -1 \\
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
i \\
|
||||
j
|
||||
\end{pmatrix}
|
||||
+
|
||||
\begin{pmatrix}
|
||||
0 \\
|
||||
2 \\
|
||||
0 \\
|
||||
4
|
||||
\end{pmatrix}
|
||||
\geq
|
||||
0
|
||||
\}
|
||||
|
||||
|
||||
Each point :math:`(i, j)` in :math:`\mathcal{P}` represents a *polyhedral statement*, that is a program statement which (1) does not induce control-flow side effects (e.g., :code:`for`, :code:`if`, :code:`break`) and (2) contains only affine functions of loop indices and global parameters in array accesses. To facilitate alias analysis, array accesses are also mathematically abstracted, using so-called *access function*. In other words, :code:`A[i][j]` is simply :code:`A[f(i,j)]` where the access function :math:`f` is defined by:
|
||||
|
||||
.. math::
|
||||
|
||||
f(i, j) = \begin{pmatrix}
|
||||
1 & 0\\
|
||||
0 & 1\\
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
i\\
|
||||
j
|
||||
\end{pmatrix}
|
||||
=
|
||||
(i, j)
|
||||
|
||||
|
||||
Note that the iteration domains of an SCoP does not specify the order in which its statements shall execute. In fact, this iteration domain may be traversed in many different possible legal orders, i.e. *schedules*. Formally, a schedule is defined as a p-dimensional affine transformation :math:`\Theta` of loop indices :math:`\mathbf{x}` and global invariant parameters :math:`\mathbf{g}`:
|
||||
|
||||
.. math::
|
||||
\Theta_S(\mathbf{x}) = T_S \begin{pmatrix}
|
||||
\vec{x}\\
|
||||
\vec{g}\\
|
||||
1
|
||||
\end{pmatrix}
|
||||
\qquad
|
||||
T_S \in \mathbb{Z} ^{p \times (\text{dim}(\mathbf{x}) + \text{dim}(\mathbf{g}) + 1)}
|
||||
|
||||
|
||||
Where :math:`\Theta_S(\mathbf{x})` is a p-dimensional vector representing the slowest to fastest growing indices (from left to right) when traversing the loop nest surrounding :math:`S`. For the code shown above, the original schedule defined by the loop nest in C can be retrieved by using:
|
||||
|
||||
.. math::
|
||||
\Theta_S(\mathbf{x}) = \begin{pmatrix}
|
||||
1 & 0 \\
|
||||
0 & 1 \\
|
||||
\end{pmatrix}
|
||||
\begin{pmatrix}
|
||||
i & j
|
||||
\end{pmatrix}^T
|
||||
=
|
||||
\begin{pmatrix}
|
||||
i & j
|
||||
\end{pmatrix}^T
|
||||
|
||||
|
||||
where :math:`i` and :math:`j` are respectively the slowest and fastest growing loop indices in the nest. If :math:`T_S` is a vector (resp. tensor), then :math:`\Theta_S` is a said to be one-dimensional (resp. multi-dimensional).
|
||||
|
||||
+++++++++++
|
||||
Advantages
|
||||
+++++++++++
|
||||
|
||||
Programs amenable to polyhedral compilation can be aggressively transformed and optimized. Most of these transformations actually boil down to the production of schedules and iteration domains that enable loop transformations promoting parallelism and spatial/temporal data locality (e.g., fusion, interchange, tiling, parallelization).
|
||||
|
||||
Polyhedral compilers can also automatically go through complex verification processes to ensure that the semantics of their input program is preserved throughout this optimization phase. Note that polyhedral optimizers are not incompatible with more standard optimization techniques. In fact, it is not uncommon for these systems to be implemented as a set of LLVM passes that can be run ahead of more traditional compilation techniques [10]_.
|
||||
|
||||
All in all, polyhedral machinery is extremely powerful, when applicable. It has been shown to support most common loop transformations, and has indeed achieved performance comparable to state-of-the-art GPU libraries for dense matrix multiplication [8]_. Additionally, it is also fully automatic and doesn't require any hint from programmers apart from source-code in a C-like format.
|
||||
|
||||
++++++++++++
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
Unfortunately, polyhedral compilers suffer from two major limitations that have prevented its adoption as a universal method for code generation in neural networks.
|
||||
|
||||
First, the set of possible program transformations $\Omega = \{ \Theta_S ~|~ S \in \text{program} \}$ is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [11]_.
|
||||
|
||||
Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [12]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework still has to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years.
|
||||
|
||||
On the other hand, blocked program representations advocated by this dissertation are less restricted in scope and can achieve close to peak performance using standard dataflow analysis.
|
||||
|
||||
-----------------------
|
||||
Scheduling Languages
|
||||
-----------------------
|
||||
|
||||
Separation of concerns \cite{dijkstra82} is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.
|
||||
|
||||
.. code-block:: python
|
||||
:linenos:
|
||||
|
||||
// algorithm
|
||||
Var x("x"), y("y");
|
||||
Func matmul("matmul");
|
||||
RDom k(0, matrix_size);
|
||||
RVar ki;
|
||||
matmul(x, y) = 0.0f;
|
||||
matmul(x, y) += A(k, y) * B(x, k);
|
||||
// schedule
|
||||
Var xi("xi"), xo("xo"), yo("yo"), yi("yo"), yii("yii"), xii("xii");
|
||||
matmul.vectorize(x, 8);
|
||||
matmul.update(0)
|
||||
.split(x, x, xi, block_size).split(xi, xi, xii, 8)
|
||||
.split(y, y, yi, block_size).split(yi, yi, yii, 4)
|
||||
.split(k, k, ki, block_size)
|
||||
.reorder(xii, yii, xi, ki, yi, k, x, y)
|
||||
.parallel(y).vectorize(xii).unroll(xi).unroll(yii);
|
||||
|
||||
|
||||
The resulting code may however not be completely portable, as schedules can sometimes rely on execution models (e.g., SPMD) or hardware intrinsics (e.g., matrix-multiply-accumulate) that are not widely available. This issue can be mitigated by auto-scheduling mechanisms [13]_.
|
||||
|
||||
+++++++++++
|
||||
Advantages
|
||||
+++++++++++
|
||||
|
||||
The main advantage of this approach is that it allows programmers to write an algorithm *only once*, and focus on performance optimization separately. It makes it possible to manually specify optimizations that a polyhedral compiler wouldn't be able to figure out automatically using static data-flow analysis.
|
||||
|
||||
Scheduling languages are, without a doubt, one of the most popular approaches for neural network code generation. The most popular system for this purpose is probably TVM, which provides good performance across a wide range of platforms as well as built-in automatic scheduling mechanisms.
|
||||
|
||||
++++++++++++
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse com-putations, whose iteration spaces may be irregular.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
| | |
|
||||
|.. code-block:: C | |pic2| |
|
||||
| | |
|
||||
| for(int i = 0; i < 4; i++) | |
|
||||
| for(int j = 0; j < 4; j++) | |
|
||||
| float acc = 0; | |
|
||||
| for(int k = 0; k < K[i]; k++) | |
|
||||
| acc += A[i][col[i,k]]*B[k][j] | |
|
||||
| C[i][j] = acc; | |
|
||||
+-----------------------------------------------------+-----------------------------------------------------+
|
||||
|
||||
.. |pic2| image:: halide-iteration.png
|
||||
:width: 300
|
||||
|
||||
On the other hand, the block-based program representation that we advocate for through this work allows for block-structured iteration spaces and allows programmers to manually handle load-balancing as they wish.
|
||||
|
||||
--------------
|
||||
References
|
||||
--------------
|
||||
|
||||
.. [1] Lattner et al., "LLVM: a compilation framework for lifelong program analysis transformation"
|
||||
.. [2] Wolfe, "More Iteration Space Tiling", SC 1989
|
||||
.. [3] Darte, "On the Complexity of Loop Fusion", PACT 1999
|
||||
.. [4] Allen et al., "Automatic Loop Interchange", SIGPLAN Notices 1984
|
||||
.. [5] Ancourt et al., "Scanning Polyhedra with DO Loops", PPoPP 1991
|
||||
.. [6] Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
|
||||
.. [7] Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiV 2018
|
||||
.. [8] Elango et al. "Diesel: DSL for Linear Algebra and Neural Net Computations on GPUs", MAPL 2018
|
||||
.. [9] Lattner et al., "MLIR Primer: A Compiler Infrastructure for the End of Moore’s Law", Arxiv 2019
|
||||
.. [10] Grosser et al., "Polly - Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012
|
||||
.. [11] Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019
|
||||
.. [12] Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006
|
||||
.. [13] Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016
|
83
docs/programming-guide/triton-c.rst
Normal file
83
docs/programming-guide/triton-c.rst
Normal file
@@ -0,0 +1,83 @@
|
||||
=======================
|
||||
The Triton-C Language
|
||||
=======================
|
||||
|
||||
In the introduction, we stressed the importance of blocked algorithms and described their core principles in pseudo-code. To facilitate their implementation on modern GPU hardware, we present Triton-C, a single-threaded imperative kernel language in which block variables are first-class citizen. This language may be used either directly by developers familiar with C, or as an intermediate language for existing (and future) transcompilers. In this chapter, we describe its differences with C, its Numpy-like semantics and its "Single-Program, Multiple-Data" (SPMD) programming model.
|
||||
|
||||
-------------------
|
||||
Differences with C
|
||||
-------------------
|
||||
|
||||
The syntax of Triton-C is based on that of ANSI C, but was modified and extended to accomodate the semantics and programming model described in the next two subsections. These changes fall into the following categories:
|
||||
|
||||
+++++++++++
|
||||
Extensions
|
||||
+++++++++++
|
||||
|
||||
**Variable declarations**: Triton adds special-purpose syntax for multi-dimensional array declarations (e.g., :code:`int block[16, 16]`), which purposely differs from that of nested arrays (i.e., arrays of pointers) found in ANSI C (e.g., :code:`int block[16][16]`). Block dimensions must be constant but can also be made parametric with the use of pre-processor macros. One-dimensional blocks of integers may be initialized using ellipses (e.g., :code:`int range[16] = 0 ... 16`).
|
||||
|
||||
**Primitive types**: Triton-C supports the following primitive data-types: :code:`bool`, :code:`uint8`, :code:`uint16`, :code:`uint32`, :code:`uint64`, :code:`int8`, :code:`int16`, :code:`int32`, :code:`int64`, :code:`half`, :code:`float`, :code:`double`.
|
||||
|
||||
**Operators and built-in function**: The usual C operators were extended to support element-wise array operations (:code:`+`, :code:`-`, :code:`&&`, :code:`*`, etc.) and complex array operations(:code:`@` for matrix multiplication). Additionally, some built-in functions were added for concurrency (:code:`get_program_id`, :code:`atomic_add`).
|
||||
|
||||
**Slicing and broadcasting**: Multi-dimensional blocks can be broadcast along any particular dimension using numpy-like slicing syntax (e.g., :code:`int array[8, 8] = range[:, newaxis]` for stacking columns). Note that, as of now, slicing blocks to retrieve sub-blocks (or scalars) is forbidden as it is incompatible with the automatic parallelization methods used by our JIT. Reductions can be achieved using a syntax similar to slicing (e.g., :code:`array[+]` for summing an array, or :code:`array[:, max]` for row-wise maximum). Currently supported reduction operators are :code:`+`, :code:`min`, :code:`max`.
|
||||
|
||||
**Masked pointer dereferencement**: Block-level operations in Triton-C are "atomic", in the sense that they execute either completely or not at all. Basic element-wise control-flow for block-level operations can nonetheless be achieved using ternary operators and the *masked pointer dereferencement* operator exemplified below:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
// create mask
|
||||
bool mask[16, 16] = ...;
|
||||
// conditional addition
|
||||
float x[16, 16] = mask ? a + b : 0;
|
||||
// conditional load
|
||||
float y[16] 16] = mask ? *ptr : 0;
|
||||
// conditional store
|
||||
*?(mask)ptr = y;
|
||||
\end{lstlisting}
|
||||
|
||||
|
||||
+++++++++++++
|
||||
Restrictions
|
||||
+++++++++++++
|
||||
|
||||
The Triton project is still in its infancy. As such, there are quite a few features of ANSI C that are not supported:
|
||||
|
||||
**Non-kernel functions**: Right now, all function definitions must be kernels, i.e. be preceded with the :code:`__global__` attribute. We are aware that this is a severe limitations, and the reason why it exists is because our automatic parallelization engine would not be capable of handling array parameter arguments.
|
||||
|
||||
**Non-primitive types**: Non-primitive types defined with :code:`struct` and :code:`union` are currently not supported, again because it is unclear at this point how these constructs would hook into our block-level data-flow analysis passes.
|
||||
|
||||
**While loops**: We just haven't had time to implement those yet.
|
||||
|
||||
----------------
|
||||
Semantics
|
||||
----------------
|
||||
|
||||
The existence of built-in **blocked** types, variable and operations in Triton-C offers two main benefits. First, it simplifies the structure of blocked programs by hiding important details pertaining to concurrent programming such as memory coalescing, cache management and specialized tensor instrinsics. Second, it opens the door for compilers to perform these optimizations automatically. However, it also means that programs have some kind of *block-level semantics* that does not exist in C. Though some aspects of it (e.g., the :code:`@` operator) are pretty intuitive, one in particular might be puzzling to some GPU programmers: broadcasting semantics.
|
||||
|
||||
+++++++++++++++++++++++
|
||||
Broadcasting Semantics
|
||||
+++++++++++++++++++++++
|
||||
|
||||
|
||||
Block variables in Triton are strongly typed, meaning that certain instructions statically require their operands to satisfy strict shape constraints. For example, a scalar may not be added to an array unless it is first appropriately broadcast. *Broadcasting semantics* (first introduced in `Numpy <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_) provides two formal rules for performing these conversions automatically in the case of binary operators: (1) the shape of the lowest-dimension operand is left-padded with ones until both operands have the same dimensionality; and (2) the content of both operands is replicated as many times as needed until their shape is identical. An error is emitted if this cannot be done.
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
int a[16], b[32, 16], c[16, 1];
|
||||
// a is first reshaped to [1, 16]
|
||||
// and then broadcast to [32, 16]
|
||||
int x_1[32, 16] = a[newaxis, :] + b;
|
||||
// Same as above but implicitly
|
||||
int x_2[32, 16] = a + b;
|
||||
// a is first reshaped to [1, 16]
|
||||
// a is broadcast to [16, 16]
|
||||
// c is broadcast to [16, 16]
|
||||
int y[16, 16] = a + c;
|
||||
|
||||
------------------
|
||||
Programming Model
|
||||
------------------
|
||||
|
||||
As discussed in the `CUDA documentation <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_, The execution of CUDA code on GPUs is supported by an `SPMD <https://en.wikipedia.org/wiki/SPMD>`_ programming model in which each kernel instance is associated with an identifiable *thread-block*, itself decomposed into *warps* of 32 *threads*. The Triton programming model is similar, but each kernel is *single-threaded* -- though automatically parallelized -- and associated with a global :code:`program id` which varies from instance to instance. This approach leads to simpler kernels in which CUDA-like concurrency primitives (shared memory synchronization, inter-thread communication, etc.) do not exist. The global program ids associated with each kernel instance can be queried using the :code:`get_program_id(axis)` built-in function where :code:`0 <= axis <= 2`. This is, for example, useful to create e.g., blocks of pointers as shown in the tutorials.
|
||||
|
BIN
docs/programming-guide/triton-parallel-matmul.png
Normal file
BIN
docs/programming-guide/triton-parallel-matmul.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.0 KiB |
Reference in New Issue
Block a user