[PYTHON][TESTS][DOC] Various improvements to the API and code quality:

* Simplified the `triton.kernel` API to achieve lower latency (see the sketch after this list):
  > `.data_ptr()` must now be passed as a kernel argument; there is no more implicit conversion from `torch.tensor`
  > compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`
  > `torch.device` must now be passed explicitly to `triton.kernel` (no longer inferred from `torch.tensor` arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
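
A rough sketch of the new call style described in the first bullet (keyword names and the kernel source are illustrative only, not taken from this diff):

```python
import torch
import triton

src = "..."  # hypothetical Triton-C source, e.g. the vector-add kernel from python/tutorials/

# torch.device and compilation options are now given explicitly at construction time
kernel = triton.kernel(src, device=torch.device('cuda:0'), defines={'TILE': 1024}, num_warps=4)

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
z = torch.empty_like(x)
N = x.numel()
grid = lambda opt: (triton.cdiv(N, opt.TILE), )   # opt.d('TILE') becomes opt.TILE
# raw .data_ptr() values are passed; tensors are no longer converted implicitly
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
```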
Philippe Tillet
2021-01-29 17:27:16 -05:00
parent a5a477c36b
commit 269ebc12e5
63 changed files with 2255 additions and 3883 deletions

View File

@@ -4,7 +4,7 @@ include(CTest)
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 # Options
-option(BUILD_TESTS "Build C++ Triton tests" ON)
+option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
 option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
 # LLVM
@@ -25,16 +25,17 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=defa
 # Tests
 if(BUILD_TESTS)
-  message(STATUS "Adding C++ tests")
-  add_subdirectory(tests)
+  message(STATUS "Adding C++ tutorials")
+  add_subdirectory(tutorials)
 endif()
 # Python module
 if(BUILD_PYTHON_MODULE)
   message(STATUS "Adding Python module")
   # PyBind11 wrapper source file
-  set(PYTHON_SRC bindings.cc launch.cc)
-  set_source_files_properties(launch.cc PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
+  set(TORCH_SRC torch/launch.cc torch/superblock.cc)
+  set(PYTHON_SRC bindings.cc ${TORCH_SRC})
+  set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
   include_directories("." ${PYTHON_INCLUDE_DIRS})
   link_directories(${PYTHON_LINK_DIRS})
 endif()

View File

@@ -1,4 +1,4 @@
-/* Copyright 2018-2020 Philippe Tillet
+/* Copyright 2018-2021 Philippe Tillet
  *
  * Permission is hereby granted, free of charge, to any person obtaining
  * a copy of this software and associated documentation files

View File

@@ -2,20 +2,11 @@
 This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
-The main components of Triton at the moment are:
-- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level (think numpy-like array operations in a C-like language).
-- **Triton-IR**: A special-purpose intermediate representation (Triton-IR) for aiding array-level program analysis and optimizations in Triton-C programs.
-- **Triton-JIT**: An optimizing just-in-time compiler for Triton-IR, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)). This includes transparent support for mixed-precision and Tensor Cores.
-Bindings for **automatic** PyTorch custom op generations are included in **PyTriton**, along with a small DSL based on einsum that supports convolutions, shift-convolutions, direct einsums, etc.
-The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!
+The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!
 ## Installation
-Triton is a fairly self-contained package and uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and LLVM-10+ for code generation.
 You can install the latest release with pip as follows:
 ```
 sudo apt-get install llvm-10-dev
@@ -30,6 +21,7 @@ or the latest development version with:
 for the C++ package:
 ```
 git clone https://github.com/ptillet/triton.git;
+cd triton;
 mkdir build;
 cd build;
 cmake ../;
@@ -39,9 +31,4 @@ make -j8;
 ## Getting Started
-Please visit the [documentation](https://docs.triton-lang.org) to get started with Triton
+You can find tutorials for Triton for [Python](https://github.com/ptillet/triton/tree/master/tutorials) and [C++](https://github.com/ptillet/triton/tree/master/python/tutorials).
-## Contributing
-Please keep in mind that this is a project I have been carrying out completely on my own as part of my Ph.D. thesis. While I am confident in the approach, there are still many things to fix and to polish. Please contact me (ptillet AT g.harvard.edu) or raise an issue if you want to contribute!

View File

@@ -11,7 +11,6 @@ Welcome to Triton's documentation!
 :caption: Contents:
 installation/index
-tutorials/index
 Indices and tables

View File

@@ -1,102 +0,0 @@
===========================
Writing a Custom Operation
===========================
--------------
Compute Kernel
--------------
Let us start with something simple, and see how Triton can be used to create a custom vector addition for PyTorch. The Triton compute kernel for this operation is the following:
.. code-block:: C
// Triton
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float* z, float* x, float* y, int N){
// program id
int pid = get_program_id(0);
// create arrays of pointers
int offset[TILE] = pid * TILE + 0 ... TILE;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
// bounds checking
bool check[TILE] = offset < N;
// write-back
*?(check)pz = *?(check)px + *?(check)py;
}
As you can see, arrays are first-class citizens in Triton. This has a number of important advantages that will be highlighted in the next tutorial. For now, let's keep it simple and see how to execute the above operation in PyTorch.
---------------
PyTorch Wrapper
---------------
As you will see, a wrapper for the above Triton function can be created in just a few lines of pure python code.
.. code-block:: python
import torch
import triton
class _add(torch.autograd.Function):
# source-code for Triton compute kernel
src = """
__global__ void add(float* z, float* x, float* y, int N){
// program id
int pid = get_program_id(0);
// create arrays of pointers
int offset[TILE] = pid * TILE + 0 ... TILE;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
// bounds checking
bool check[TILE] = offset < N;
// write-back
*?(check)pz = *?(check)px + *?(check)py;
}
"""
# create callable kernel for the source-code
# options: 4 warps and a -DTILE=1024
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
# Forward pass
@staticmethod
def forward(ctx, x, y):
# type checking
assert x.dtype == torch.float32
# allocate output
z = torch.empty_like(x).cuda()
# create launch grid
# this is a function of the launch parameters
# triton.cdiv indicates ceil division
N = x.numel()
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
# launch kernel
_add.kernel(z, x, y, N, grid = grid)
# return output
return z
# get callable from Triton function
add = _add.apply
# test
torch.manual_seed(0)
x = torch.rand(98432).cuda()
y = torch.rand(98432).cuda()
za = x + y
zb = add(x, y)
diff = (za - zb).abs().max()
print(diff)
print(torch.allclose(za,zb))
Executing the above code will:
- Generate a .cpp file containing PyTorch bindings for the Triton function
- Compile this .cpp file using distutils
- Cache the resulting custom op
- Call the resulting custom op
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.

View File

@@ -1,10 +0,0 @@
Tutorials
==========
.. toctree::
:maxdepth: 1
custom-operation
triton-vs-cuda
matrix-transposition
matrix-multiplication

View File

@@ -1,186 +0,0 @@
*********************
Matrix Multiplication
*********************
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see `CUTLASS <https://github.com/NVIDIA/cutlass>`_). We will also see how pre-processor macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
*Note: Bounds-checking is omitted throughout for the sake of clarity. This feature can easily be added to our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops.*
==============
Compute Kernel
==============
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
.. code-block:: C
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0); //(1)
int pn = get_program_id(1); //(2)
int rm[TM] = pm * TM + 0 ... TM; //(3)
int rn[TN] = pn * TN + 0 ... TN; //(4)
int rk[TK] = 0 ... TK; //(5)
// initialize accumulator
float c[TM, TN] = 0; //(6)
// pointers to operands
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8)
// reduction loop
for(int k = K; k > 0; k-= TK){
// fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
}
// epilogue
TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14)
*pc = c; //(15)
}
Here, each kernel instance produces a :code:`TM x TN` tile of the output matrix C as follows:
- Statements (1) - (2) fetch the id of the current program instance.
- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix :code:`C`
- Statement (5) constructs a range of indices along the reduction axis: :code:`rk = [0, 1, ..., TK - 1]`
- Statement (6) initializes a :code:`TM x TN` array of accumulators to hold the result of :code:`A[rm, :] x B[:, rn]`
- Statements (7) - (8) initialize arrays of pointers :code:`pa` and :code:`pb` to the operands :code:`A` and :code:`B` using logic similar to that of the above transposition kernel
- Statements (9) - (10) load tiles of operands by dereferencing :code:`pa` and :code:`pb`
- Statement (11) updates the accumulator array using Triton-C's matrix-multiplication operator :code:`@`
- Statements (12) - (13) update :code:`pa` and :code:`pb`
- Statement (14) creates an array of pointers `pc` to the result matrix :code:`C`
- Statement (15) writes back the accumulator to :code:`C`
Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel:
- Automatic coalescing of load/store operations
- Automatic vectorization of load/store operations
- Stashing `a` and `b` to shared memory
- Automatic allocation of shared memory
- Automatic synchronization of shared memory
- Automatic padding of shared memory to avoid bank conflicts
- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0
==============
Optimizations
==============
Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
-------------
Pre-Fetching
-------------
The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows:
.. code-block:: C
// pre-fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
bool check = k > TK;
// pre-fetch operands
a = check ? *pa : 0;
b = check ? *pb : 0;
}
Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard.
-----------------
Rematerialization
-----------------
`Rematerialization <https://en.wikipedia.org/wiki/Rematerialization>`_ is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only :code:`rm` and :code:`rn` need to be rematerialized, leading to the following epilogue:
.. code-block:: C
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
------------------------------------
Fused Transpositions and Auto-Tuning
------------------------------------
It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel:
.. code-block:: C
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0);
int pn = get_program_id(1);
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
}
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
}
All matrix multiplication variants can then be retrieved using the following compilation options:
.. code-block:: C
// A is not transposed
-DUSE_A=a -DSTRIDE_AK=1 -DSTRIDE_AM=lda
-DBROADCAST_AK=newaxis,: -DBROADCAST_AM=:,newaxis -DSHAPE_A=TM,TK
// A is transposed
-DUSE_A=^a -DSTRIDE_AK=lda -DSTRIDE_AM=1
-DBROADCAST_AK=:,newaxis -DBROADCAST_AM=newaxis,: -DSHAPE_A=TK,TM
// B is not transposed
-DUSE_B=b -DSTRIDE_BK=ldb -DSTRIDE_BN=1
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
// B is transposed
-DUSE_B=^b -DSTRIDE_BK=1 -DSTRIDE_BN=ldb
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
Auto-tuning can also be handled using pre-processor macros:
.. code-block:: C
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.

View File

@@ -1,174 +0,0 @@
*********************
Matrix Transpositions
*********************
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity and interferes with auto-tuning.
==============
Compute Kernel
==============
In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory. Therefore, an optimal Triton kernel for this operation would look like:
.. code-block:: C
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
__global__ void transpose(TYPE * X, TYPE * Y,
int M, int N, int ldx, int ldy) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}
At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:
- Statements (1) and (2) extract the coordinates of the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TM:3TM-1]` holds the values:
.. code-block:: C
pidm = 2
pidn = 1
- Statements (3) and (4) construct the ranges of indices:
.. code-block:: C
rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
which will be used in statements (5) and (6) to construct tiles of pointers
- Statement (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
::
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
│ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx) │
- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics:
::
│ Y + (pidn*TN + 0) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + 0) + (pidm*TM + TM - 1)*ldy) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
│ Y + (pidn*TN + TN - 1) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy) │
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
==================================
A Note on Numpy-style Broadcasting
==================================
The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.
First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix.
.. code-block:: python
import numpy as np
vec = np.linspace(0,11,12)
mat = vec.reshape((4,3))
Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom).
::
[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]
Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like for tile 0 to have the values:
::
[ 0, 1, 2],
[ 3, 4, 5],
We construct ranges `rm` and `rn` as:
::
rm = [0, 1]
rn = [0, 1, 2]
Using numpy-style broadcasting, we can add these together to create a matrix:
::
rm[:, np.newaxis] + rn[np.newaxis, :]
rn ->           [0, 1, 2]
rm -> [[0.],   [[0, 1, 2],
       [1.]]    [1, 2, 3]]
The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index
of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need
to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension).
In this case this is 3, so what we really need is:
::
ldx = 3
px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:]
`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example).
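For concreteness, here is the same offset computation as a small, runnable numpy snippet (matching the `ldx = 3` example above):
.. code-block:: python
import numpy as np
ldx = 3                    # leading dimension: number of columns per row
rm = np.arange(2)          # tile row indices:    [0, 1]
rn = np.arange(3)          # tile column indices: [0, 1, 2]
# broadcast row offsets against column offsets into a 2x3 grid of flat offsets
px = rm[:, np.newaxis] * ldx + rn[np.newaxis, :]
print(px)
# [[0 1 2]
#  [3 4 5]]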
==========================
The __multipleof attribute
==========================
The memory loads and stores in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate because tensor dimensions can easily be made into nice powers of two in Deep Learning, due to batch sizes and layer widths being flexible.
For this reason, Triton provides a __multipleof(N) attribute for variables that are guaranteed to always be a multiple of N. In the case of Matrix Transpositions, vector loads can be enabled by modifying the function's signature as follows:
.. code-block:: C
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
int ldx __multipleof(8),
int ldy __multipleof(8)) {
// ...
}
==========================
Bounds Checking
==========================
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
.. code-block:: C
// launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
// ...
// create bounds-checking mask
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b)
// conditional write-back using the conditional dereferencing operator '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
Here, statement (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.

View File

@@ -1,180 +0,0 @@
====================================================
Putting It All Together
====================================================
In the previous tutorial, we saw how to write tensor-core-friendly matrix multiplication code competitive with cuBLAS in 20 lines of Triton code. Here, we will see how to wrap it into an automatically differentiable PyTorch function for easy integration in your Deep Learning pipeline.
-----------------
PyTriton Function
-----------------
The PyTriton API provides a :code:`triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class.
.. code-block:: python
import torch
import triton
# Entry point
class _dot(torch.autograd.Function):
@staticmethod
# Forward Pass
def forward(ctx, *args):
#...
@staticmethod
# Backward Pass
def backward(ctx, dy):
#...
-----------------
PyTriton Kernels
-----------------
PyTriton also provides a :code:`triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework bindings code. For our dot operation we create a kernel from the Triton code shown at the end of the previous tutorial.
.. code-block:: python
src = """
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0);
int pn = get_program_id(1);
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
}
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
}
"""
kernel = triton.kernel(src)
At this point, `kernel` is a callable object which takes the same signature as the :code:`dot` function in our source code, except that pointers are treated as tensors: :code:`[tensor, tensor, tensor, int, int, int, int, int, int]`.
-----------------------
Using PyTriton Kernels
-----------------------
However, in practice only :code:`A` and :code:`B` are provided by the user, and all the other :code:`int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the :code:`A` and :code:`B` tensors, and then returns the result of a call to :code:`kernel`:
.. code:: python
@staticmethod
def _call(a, b, transpose_a, transpose_b):
# extract shapes
shape_a = a.shape
shape_b = b.shape
M, Ka = shape_a[0], shape_a[1]
Kb, N = shape_b[0], shape_b[1]
# transpose shapes
if transpose_a:
M, Ka = Ka, M
if transpose_b:
Kb, N = N, Kb
# contiguous dimensions
lda = M if transpose_a else Ka
ldb = Kb if transpose_b else N
ldc = N
# data-type
dtype = a.dtype
# allocate output
c = triton.empty([M, N], dtype = dtype)
# launch grid
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
# pre-processor definitions
defines = {# tile sizes
'TYPE' : dtype,
'AT' : transpose_a,
'BT' : transpose_b,
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [8],
# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
'STRIDE_AM' : '1' if transpose_a else 'lda',
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
# handle B transposition
'USE_B' : '^b' if transpose_b else 'b',
'STRIDE_BK' : '1' if transpose_b else 'ldb',
'STRIDE_BN' : 'ldb' if transpose_b else '1',
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
grid=grid, num_warps=4, defines=defines)
--------------------------------------------
Automatic Differentiation
--------------------------------------------
At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
Creating custom operations for Triton and PyTorch is very similar; programmers have to provide two static methods :code:`forward` and :code:`backward` that take a context as their first input:
.. code:: python
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
return _dot._call(a, b, transpose_a, transpose_b)
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
t_a, t_b = ctx.t_a, ctx.t_b
if not t_a and not t_b:
da = _dot._call(dy, b, False, True)
db = _dot._call(a, dy, True, False)
elif not t_a and t_b:
da = _dot._call(dy, b, False, False)
db = _dot._call(dy, a, True, False)
elif t_a and not t_b:
da = _dot._call(b, dy, False, True)
db = _dot._call(a, dy, False, False)
elif t_a and t_b:
da = _dot._call(b, dy, True, True)
db = _dot._call(dy, a, True, True)
else:
assert False
return da, db, None, None, None, None, None, None, None
A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.
.. code:: python
dot = _dot.apply
And that's it! In just ~100 lines of pure Python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source!
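For completeness, here is a hypothetical usage sketch of the resulting op (shapes and device are illustrative, not part of the original tutorial):
.. code:: python
a = torch.rand(128, 64, device='cuda', requires_grad=True)
b = torch.rand(64, 256, device='cuda', requires_grad=True)
c = dot(a, b)          # forward pass launches the auto-tuned Triton kernel
c.sum().backward()     # backward pass re-uses _dot._call through _dot.backward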

View File

@@ -1,109 +0,0 @@
***************
Triton vs. CUDA
***************
The purpose of this tutorial is to explore in more depth the major differences between Triton and CUDA. To keep things simple, we will still be focusing on the following vector addition code:
.. code-block:: C
// Triton
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float* z, float* x, float* y, int N){
int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE;
bool check[TILE] = offset < N;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
*?(check)pz = *?(check)px + *?(check)py;
}
And its CUDA equivalent:
.. code-block:: C
// CUDA
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float *z, float *x, float *y, int N) {
int off = blockIdx.x * TILE + threadIdx.x;
if(off < N){
float *pz = z + off;
float *px = x + off;
float *py = y + off;
*pz = *px + *py;
}
}
==========================
Automatic parallelization
==========================
While the two above pieces of code may look similar at first sight, a closer look reveals one *fundamental* difference: While CUDA kernels are launched on a cooperative array of threads, **Triton kernels are single-threaded and automatically parallelized**.
This is a major difference in programming model, which not only makes your life much easier as a programmer, but also allows the Triton compiler to automatically do all sorts of nice optimizations:
- *Automatic shared memory allocation and synchronization*
That's right; programmers don't need to worry about shared memory allocation, usage and synchronization. Instead, the Triton compiler will use complex program analysis techniques to determine when shared memory should be used, where it should be synchronized and how threads should access it to avoid memory bank conflicts.
- *Automatic memory coalescing*
When you write Triton code, you also don't need to worry about memory coalescing. The compiler will arrange threads so that global memory accesses are coalesced when possible.
- *Automatic tensor core utilization*
Using tensor cores on Volta and Turing is notoriously difficult. Code is hard to write and even harder to optimize. Fortunately, the Triton compiler can also generate very efficient tensor core instructions (e.g., :code:`mma.sync.m8n8k4`) when low-precision matrices are multiplied together:
.. code-block:: C
half A[16, 8] = ... // initialize A
half B[8, 16] = ... // initialize B
float C[16, 16] = dot(A, B); // uses Tensor Cores!
- *Automatic instruction predication*
Contrary to CUDA, Triton directly exposes predicated instructions through masked load/store operations. This enables the Triton compiler to generate predicated instructions in PTX directly, sometimes resulting in better performance than I/O operations wrapped inside conditionals.
===========================
Vector Addition - Revisited
===========================
In light of these optimizations, it turns out that the GPU code generated by our Triton-C vector addition code is actually more analogous to the following:
.. code-block:: C
// CUDA
// launch on a grid of (N + TILE - 1) / TILE programs
__global__ void add(float *z, float *x, float *y, int N) {
int off[4];
#pragma unroll
for(int k = 0; k < 4; k++)
off[k] = blockIdx.x * TILE + threadIdx.x + k * blockDim.x;
#pragma unroll
for(int k = 0; k < 4; k++)
z[off[k]] = x[off[k]] + y[off[k]];
}
This code is actually more complicated when x, y and z have :code:`half` type, because then the Triton compiler automatically vectorizes data accesses using :code:`half2` to ensure memory transactions of 32-bits per thread.
============================
Auto-Tuning
============================
Now assume that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be trivially done by adjusting some compilation parameters. For example:
.. code-block:: python
kernel = triton.kernel(src, defines = {'TILE': [256, 512, 1024]}, num_warps = [2, 4, 8])
would benchmark our above triton source-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.
=============================
Going Further
=============================
The benefits of Triton become more and more pronounced as compute kernels get more and more complex. In the next few tutorials, you will see how to implement transposition and tensor-core-compatible matrix multiplication routines on par with cuBLAS and CUTLASS without having to know anything about GPU micro-architecture!

View File

@@ -29,7 +29,7 @@ public:
 static driver::stream* create(backend_t backend);
 // methods
 virtual void synchronize() = 0;
-virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args = NULL, size_t args_size = 0) = 0;
+virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
 virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
 virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
 // template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
 public:
 host_stream();
 void synchronize();
-void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size);
+void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
 void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
 void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
@@ -55,7 +55,7 @@ public:
 cu_stream(CUstream str, bool take_ownership);
 cu_stream();
 void synchronize();
-void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size);
+void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
 void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
 void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };

View File

@@ -4,27 +4,19 @@
#define _TRITON_RUNTIME_FUNCTION_H_ #define _TRITON_RUNTIME_FUNCTION_H_
#include <map> #include <map>
#include <unordered_map>
#include <vector> #include <vector>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <memory> #include <memory>
#include <functional> #include <functional>
#include <set>
// codegen // codegen
#include "triton/ir/context.h" #include "triton/ir/context.h"
#include "triton/codegen/target.h"
#include "triton/runtime/arg.h" #include "triton/runtime/arg.h"
#include "triton/runtime/error.h" #include "triton/runtime/error.h"
namespace llvm { // driver forward declaration
class Module;
class LLVMContext;
}
class Parser;
namespace triton { namespace triton {
namespace driver{ namespace driver{
class module; class module;
class stream; class stream;
@@ -32,26 +24,19 @@ namespace driver{
class context; class context;
class device; class device;
} }
namespace lang{
class translation_unit;
} }
// ir forward declaration
namespace codegen{ namespace triton{
namespace analysis{
class tiles;
}
}
namespace ir { namespace ir {
class module; class module;
class function; class function;
class context; class context;
} }
}
namespace triton{
namespace runtime{ namespace runtime{
typedef std::vector<size_t> grid_t; typedef std::vector<size_t> grid_t;
typedef std::map<std::string, size_t> params_t; typedef std::map<std::string, size_t> params_t;
template<typename T> inline T convert(const std::string& name); template<typename T> inline T convert(const std::string& name);
@@ -73,7 +58,6 @@ struct options_space_t {
typedef std::pair<std::string, std::vector<std::string>> define_t; typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines; std::vector<define_t> defines;
std::vector<int> num_warps; std::vector<int> num_warps;
std::vector<int> recompile_key;
}; };
struct options_t { struct options_t {
@@ -81,88 +65,69 @@ struct options_t {
T D(const std::string& name) const { T D(const std::string& name) const {
return convert<T>(defines.at(name)); return convert<T>(defines.at(name));
} }
bool operator<(const options_t& other) const { std::unordered_map<std::string, std::string> defines;
return std::make_pair(defines, num_warps) <
std::make_pair(other.defines, other.num_warps);
}
std::string to_str() const;
std::map<std::string, std::string> defines;
size_t num_warps; size_t num_warps;
}; };
/* ------------------------- */
class kernel{
private:
static std::string preheader();
static arg_type convert(ir::type *ty);
public:
kernel(const std::string& src, const options_t& opt, driver::device *device);
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
// getters
const std::vector<arg_type>& get_sig() const { return sig_; }
private:
void init_ir (const std::string &src);
void init_ker();
void init_sig();
public:
const options_t opt;
private:
driver::device* dev_;
// signature
std::vector<arg_type> sig_;
// triton context for parsing
ir::context ctx_;
// handles
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
std::shared_ptr<driver::kernel> ker_;
};
class function { class function {
public: public:
typedef std::function<grid_t(const options_t&)> grid_fn_ty; typedef std::function<grid_t(const options_t&)> grid_fn_ty;
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
private: private:
class caller { static void do_loop_nest(std::vector<size_t> const & ranges,
public: std::function<void(std::vector<size_t> const &)> const & f);
// constructors
caller(std::ifstream& ifs, const options_t& opt);
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
// serialization
void write(std::ofstream& ofs);
void read(std::ifstream& ifs);
// accessors
const options_t opt() const { return opt_; }
const driver::module* parent() const { return &*parent_; }
const driver::kernel* bin() const { return &*bin_; }
arg_type param_ty(size_t i) const { return param_tys_.at(i);}
const std::vector<arg_type>& param_tys() const { return param_tys_; }
std::vector<int> retune() const { return retune_; }
// entry points
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size, const std::map<std::string, std::vector<char>>& = {}) const;
private:
std::shared_ptr<driver::kernel> bin_;
std::shared_ptr<driver::module> parent_;
std::vector<arg_type> param_tys_;
std::vector<int> retune_;
options_t opt_;
std::string name_;
};
private:
typedef std::pair<driver::device*, std::vector<int32_t>> cache_key_t;
private:
// cache
static std::string get_cache_prefix();
// make
triton::lang::translation_unit *make_ast(const std::string &src);
std::unique_ptr<ir::module> make_ir(Parser &parser);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::device *device, const options_t &opt);
void make(driver::device *device, options_t opt);
void precompile(driver::device *device, const options_space_t& tuning_space);
// autotune
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
public: public:
static std::string preheader(); function(const std::string& src, const options_space_t& opt, driver::device *device);
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
public: void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = ""); // auto-tuning
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device); cache_t::iterator find_in_cache(void* args, size_t args_size);
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device); kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void set_cst(const char* name, void* data, size_t n_bytes); // getters
std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt); const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
private: private:
std::map<std::string, std::vector<char>> cst_; void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device);
// pre-compilation
ir::context ctx_; private:
std::string src_; std::vector<kernel_pair_t> kernels_;
options_space_t opt_; std::map<std::vector<uint64_t>, kernel*> cache_;
std::set<options_t> compiled_;
std::map<options_t, std::unique_ptr<caller>> callers_;
std::vector<int> args_off_;
size_t args_size_;
// caching
std::string cache_ref_;
std::string cache_path_;
std::map<cache_key_t, caller*> cache_;
}; };
} }

View File

@@ -168,7 +168,6 @@ mma_layout::mma_layout(size_t num_warps,
 bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
 int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
 int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
-pack_size_0 = pack_size_1 = 1;
 rep_ = {2*pack_size_0, 2*pack_size_1, 1};
 spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
 }

View File

@@ -76,7 +76,7 @@ void host_stream::synchronize() {
 hst_->args.clear();
 }
-void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size) {
+void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
 auto hst = kernel->module()->hst();
 hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
 char* params = new char[args_size];
@@ -113,7 +113,7 @@ void cu_stream::synchronize() {
 dispatch::cuStreamSynchronize(*cu_);
 }
-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void** args, size_t args_size) {
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
 void *config[] = {
 CU_LAUNCH_PARAM_BUFFER_POINTER, args,
 CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,

View File

@@ -108,11 +108,9 @@ void Parser::ParseTranslationUnit() {
 FuncDef* Parser::ParseFuncDef(Identifier* ident) {
 auto funcDef = EnterFunc(ident);
 if (funcDef->FuncType()->Complete()) {
 Error(ident, "redefinition of '%s'", funcDef->Name().c_str());
 }
 // TODO(wgtdkp): param checking
 auto funcType = ident->Type()->ToFunc();
 funcType->SetComplete(true);

View File

@@ -1,4 +1,4 @@
 #include <string>
 #include <mutex>
 #include <regex>
 #include <functional>
@@ -45,360 +45,25 @@ std::mutex mut;
namespace triton{ namespace triton{
namespace runtime { namespace runtime {
/* --------------------- */ /* --------------------------------- */
/* HELPERS */ /* --------------------------------- */
/* --------------------- */ /* --------------------------------- */
void _loop_nest(std::vector<size_t> const & ranges, arg_type kernel::convert(ir::type *ty) {
std::function<void(std::vector<size_t> const &)> const & f){ if(ty->is_integer_ty(1)) return INT1_T;
size_t D = ranges.size(); if(ty->is_integer_ty(8)) return INT8_T;
std::vector<size_t> values(D, 0); if(ty->is_integer_ty(16)) return INT16_T;
size_t i = D - 1; if(ty->is_integer_ty(32)) return INT32_T;
while(true){ if(ty->is_integer_ty(64)) return INT64_T;
f(values); if(ty->is_half_ty()) return HALF_T;
while(values[i]++ == ranges[i] - 1){ if(ty->is_float_ty()) return FLOAT_T;
if(i == 0) if(ty->is_double_ty()) return DOUBLE_T;
return; if(ty->is_pointer_ty()) return BUFFER_T;
values[i--] = 0;
}
i = D - 1;
}
}
/* --------------------- */
/* OPTIONS */
/* --------------------- */
std::string options_t::to_str() const{
std::string ret = "nw-" + std::to_string(num_warps);
for(const auto& x : defines){
ret += '-';
ret += x.first;
ret += '-';
ret += x.second;
}
// legalize
for(char& x: ret){
if(x == ' ' || x == '^' || x == ',' || x == ':')
x = '_';
}
return ret;
}
/* --------------------- */
/* CALLER OBJECT */
/* --------------------- */
arg_type convert(ir::type *ty) {
if(ty->is_integer_ty(1))
return INT1_T;
if(ty->is_integer_ty(8))
return INT8_T;
if(ty->is_integer_ty(16))
return INT16_T;
if(ty->is_integer_ty(32))
return INT32_T;
if(ty->is_integer_ty(64))
return INT64_T;
if(ty->is_half_ty())
return HALF_T;
if(ty->is_float_ty())
return FLOAT_T;
if(ty->is_double_ty())
return DOUBLE_T;
if(ty->is_pointer_ty())
return BUFFER_T;
throw std::runtime_error("unknown type"); throw std::runtime_error("unknown type");
} }
//void function::caller::write(std::ofstream &ofs) {
// // write name
// ofs << name_ << std::endl;
// // write signature
// for(size_t i = 0; i < param_tys_.size(); i++)
// ofs << param_tys_[i] << " ";
// ofs << std::endl;
// // write module
// std::string source = ((driver::cu_module*)(&*parent_))->ptx();
// ofs << source;
//}
//void function::caller::read(driver::context* ctx, std::ifstream &ifs) { std::string kernel::preheader() {
// // read name
// std::getline(ifs, name_);
// // read signature
// std::string line;
// std::getline(ifs, line);
// std::istringstream current(line);
// int param;
// param_tys_.clear();
// while(current >> param)
// param_tys_.push_back((arg_type)param);
// // read module
// std::string src((std::istreambuf_iterator<char>(ifs)),
// std::istreambuf_iterator<char>());
// parent_.reset(new driver::cu_module(ctx, src));
// bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
//}
//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
// : opt_(opt) {
// read(ctx, ifs);
//}
function::caller::caller(ir::function *ir,
std::shared_ptr<driver::module> parent, const options_t& opt)
: parent_(parent), opt_(opt), name_(ir->get_name()) {
bin_.reset(driver::kernel::create(&*parent, name_.c_str()));
// extract signature
ir::function_type* ty = ir->get_fn_type();
for(size_t i = 0; i < ty->get_num_params(); i++){
param_tys_.push_back(convert(ty->get_param_ty(i)));
if(!ir->has_attr(i+1))
continue;
for(ir::attribute attr: ir->attrs().at(i + 1))
if(attr.get_kind() == ir::retune)
retune_.push_back(i);
}
}
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size, const std::map<std::string, std::vector<char>>& csts) const {
// copy constants
for(const auto& cst: csts){
std::unique_ptr<driver::buffer> buffer = parent()->symbol(cst.first.c_str());
stream->write(&*buffer, true, 0, cst.second);
}
// set grid
if(_grid.size() > 3)
throw std::runtime_error("grid size must be no greater than 3");
std::array<size_t, 3> grid;
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, args, args_size);
}
/* --------------------- */
/* FUNCTION */
/* --------------------- */
// create Triton-IR from AST
std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
ir::module* module = new ir::module("", ctx_);
Generator gen(&parser);
gen.Gen(module);
return std::unique_ptr<ir::module>(module);
}
// create Binary from Triton-IR
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::device* device, const options_t& opt) {
std::unique_ptr<codegen::target> target = device->make_target();
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// optimizations
bool cts_use_async = target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get());
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
// run passes
dce.run(module);
disassociate.run(module);
dce.run(module);
peephole.run(module);
dce.run(module);
align.run(module);
if(target->is_gpu())
cts.run(module);
axes.run(module);
layouts.run(module);
coalesce.run(module);
dce.run(module);
align.run(module);
dce.run(module);
if(target->is_gpu()){
reassociate.run(module);
cts.run(module);
}
peephole.run(module);
dce.run(module);
align.run(module);
axes.run(module);
// ir::print(module, std::cout);
layouts.run(module);
swizzle.run(module);
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > device->max_shared_memory())
throw exception::out_of_shared_memory();
barriers.run(module);
isel.visit(module, *llvm);
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
// if(res->spilled() > 256)
// throw exception::out_of_registers();
return res;
}
// create Binary from options
void function::make(driver::device *device, options_t opt) {
if(callers_.find(opt) != callers_.end())
return;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&src_, true);
for(auto it: opt.defines)
cpp.AddMacro(it.first, &it.second);
cpp.Process(tokens);
// src -> ast
Parser parser(tokens);
parser.Parse();
// ast -> triton-ir
auto ir = make_ir(parser);
// triton-ir -> binary
std::unique_ptr<driver::module> bin;
try{
bin = make_bin(*ir, device, opt);
}catch(const exception::base&){
throw;
}
// create callable
ir::function *tmp = ir->get_function_list()[0];
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
}
// precompile all kernels spanned by given options space
void function::precompile(driver::device* device, const options_space_t& space) {
// all ranges
std::vector<size_t> ranges;
ranges.push_back(space.num_warps.size());
for(const auto& x: space.defines)
ranges.push_back(x.second.size());
// functor for source with given option
std::map<options_t, std::string> err;
auto do_make = [&](std::vector<size_t> params) {
// compilation options
unsigned i = 0;
options_t opt;
opt.num_warps = space.num_warps[params[i++]];
for(auto D: space.defines)
opt.defines[D.first] = D.second[params[i++]];
// compile
try{
make(device, opt);
}catch(const exception::base& e){
err[opt] = e.what();
}
};
// multi-threaded compilation
_loop_nest(ranges, do_make);
if(callers_.empty()){
std::ostringstream dbg;
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
for(auto x: err){
dbg << "[ ";
dbg << x.first.num_warps << ", ";
dbg << "{ ";
for(const auto& y: x.first.defines)
dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
dbg << " } ] -> " << x.second << std::endl;
}
throw exception::no_valid_configuration(dbg.str());
}
}
std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) {
make(device, opt);
const auto& fn = callers_.at(opt);
if(!fn)
return "";
switch(mode){
case ASM_LLIR:{
return fn->parent()->llir();
}
case ASM_NV_PTX:
case ASM_NV_SASS:{
std::string ptx = ((driver::cu_module*)fn->parent())->ptx();
// SASS
std::string input = std::tmpnam(nullptr);
std::string output = std::tmpnam(nullptr);
std::ofstream ofs(input);
ofs << ptx;
ofs.close();
if(mode == ASM_NV_PTX)
return ptx;
std::string cmd;
int err;
// compile ptx
driver::cu_device* cu_device = (driver::cu_device*)device;
cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
err = system(cmd.c_str());
// disassemble
cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
err = system(cmd.c_str());
std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
std::string to_delete = " /*";
std::ifstream ifs(output);
std::string line;
std::string sass;
while(std::getline(ifs, line))
if(!std::regex_match(line, comment))
sass += line + "\n";
return sass;
}
default:
return "";
}
}
// returns program with best compilation options for given parameter
function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
void** args, size_t args_size) {
// fast path -- no autotuning necessary
if(callers_.size() == 1)
return &*callers_.begin()->second;
// run auto-tuner
double best_ts = INFINITY;
caller* ret = nullptr;
for(auto &x : callers_){
if(x.second == nullptr)
throw std::runtime_error("configuration not compiled");
caller* current = &*x.second;
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size, cst_); },
stream, true);
ret = (ts < best_ts) ? current : ret;
best_ts = std::min(ts, best_ts);
}
stream->synchronize();
return ret;
}
// set copy host buffer "data" into constant memory buffer "name"
void function::set_cst(const char* name, void* data, size_t n_bytes) {
cst_[std::string(name)] = std::vector<char>((char*)data, (char*)data + n_bytes);
}
std::string function::preheader() {
return R"( return R"(
#define bool _Bool #define bool _Bool
#define true 1 #define true 1
@@ -452,67 +117,212 @@ typedef long int64;
)"; )";
} }
std::string function::get_cache_prefix() { void kernel::init_ir(const std::string& src) {
//user-specified cache path // pre-process
std::string result = tools::getenv("TRITON_CACHE_PATH"); TokenSequence tokens;
if(!result.empty()){ Preprocessor cpp(&src, true);
if(tools::mkpath(result)==0) for(auto it: opt.defines)
return result; cpp.AddMacro(it.first, &it.second);
} cpp.Process(tokens);
//create in home // src -> ast
result = tools::getenv("HOME"); Parser parser(tokens);
if(!result.empty()) parser.Parse();
{ // ast -> triton-ir
result = result + "/.triton/cache/"; ir::module* module = new ir::module("", ctx_);
if(tools::mkpath(result)==0) Generator gen(&parser);
return result; gen.Gen(module);
} ir_.reset(module);
return "";
} }
function::function(const std::string &src, void kernel::init_ker(){
const options_space_t& opt, // triton-ir -> binary
const std::string &cache_ref): std::unique_ptr<driver::module> bin;
src_(src), opt_(opt), cache_ref_(cache_ref) { std::unique_ptr<codegen::target> target = dev_->make_target();
// hash source code // generate llvm code
unsigned char hash[20]; llvm::LLVMContext ctx;
sha1::calc((void*)src_.data(), src_.size(), hash); std::string name = ir_->get_function_list()[0]->get_name();
// create cache path std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
char _hex[40]; // optimizations
sha1::toHexString(hash, _hex); bool cts_use_async = target->as_nvidia()->sm() >= 80;
std::string hex(_hex, _hex + 40); // create passes
cache_path_ = get_cache_prefix() + hex + "/"; codegen::analysis::align align;
tools::mkpath(cache_path_); codegen::analysis::axes axes;
// append pre-header to source codegen::transform::cts cts(cts_use_async);
src_ = preheader() + src_; codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get());
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
// run passes
dce.run(*ir_);
disassociate.run(*ir_);
dce.run(*ir_);
peephole.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
if(target->is_gpu())
cts.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
coalesce.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
dce.run(*ir_);
if(target->is_gpu()){
reassociate.run(*ir_);
cts.run(*ir_);
}
peephole.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
swizzle.run(*ir_);
liveness.run(*ir_);
allocation.run(*ir_);
if(allocation.allocated_size() > dev_->max_shared_memory())
throw exception::out_of_shared_memory();
barriers.run(*ir_);
isel.visit(*ir_, *llvm);
//if(res->spilled() > 256)
// throw exception::out_of_registers();
mod_.reset(driver::module::create(dev_, std::move(llvm)));
ker_.reset(driver::kernel::create(&*mod_, name.c_str()));
} }
void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream, driver::device *device) { void kernel::init_sig() {
// pre-compile kernels ir::function* fn = ir_->get_function_list()[0];
if(callers_.empty()){ ir::function_type* ty = fn->get_fn_type();
precompile(device, opt_); for(size_t i = 0; i < ty->get_num_params(); i++){
sig_.push_back(convert(ty->get_param_ty(i)));
if(!fn->has_attr(i+1))
continue;
} }
// re-tuning key }
cache_key_t key;
key.first = device; kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev):
key.second = callers_.begin()->second->retune(); opt(opt), dev_(dev) {
// auto-tune if necessary init_ir(preheader() + src);
init_ker();
init_sig();
}
void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
// set grid
if(_grid.size() > 3)
throw std::runtime_error("grid size must be no greater than 3");
std::array<size_t, 3> grid;
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size);
}
/* --------------------------------- */
/* --------------------------------- */
/* --------------------------------- */
void function::do_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f){
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
size_t i = D - 1;
while(true){
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}
void function::init_kernels(const std::string& src, const options_space_t& opts, driver::device *device) {
// all ranges
std::vector<size_t> ranges;
ranges.push_back(opts.num_warps.size());
for(const auto& x: opts.defines)
ranges.push_back(x.second.size());
// functor for source with given option
std::vector<std::pair<options_t, std::string>> err;
auto do_make = [&](std::vector<size_t> params) {
// compilation options
unsigned i = 0;
options_t opt;
opt.num_warps = opts.num_warps[params[i++]];
for(auto D: opts.defines)
opt.defines[D.first] = D.second[params[i++]];
// compile
try{
kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
}catch(const exception::base& e){
err.push_back({opt, e.what()});
}
};
// multi-threaded compilation
do_loop_nest(ranges, do_make);
if(kernels_.empty()){
std::ostringstream dbg;
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
for(auto x: err){
dbg << "[ ";
dbg << x.first.num_warps << ", ";
dbg << "{ ";
for(const auto& y: x.first.defines)
dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
dbg << " } ] -> " << x.second << std::endl;
}
throw exception::no_valid_configuration(dbg.str());
}
}
kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) {
// fast path -- no autotuning necessary
if(kernels_.size() == 1)
return &*kernels_.begin()->second;
// auto-tuning key
std::vector<uint64_t> key;
auto it = cache_.find(key); auto it = cache_.find(key);
if(it == cache_.end()){ if(it != cache_.end())
auto best = autotune(stream, grid_fn, args, args_size); return it->second;
it = cache_.insert({key, best}).first; // run auto-tuner
double best_ts = INFINITY;
kernel* ret = nullptr;
for(auto &x : kernels_){
kernel* current = &*x.second;
auto grid = grid_fn(x.first);
while(grid.size() < 3)
grid.push_back(1);
double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
stream, true);
ret = (ts < best_ts) ? current : ret;
best_ts = std::min(ts, best_ts);
} }
// run stream->synchronize();
(*it->second)(stream, grid_fn(it->second->opt()), args, args_size, cst_); it = cache_.insert({key, ret}).first;
return it->second;
} }
void function::operator()(void** args, function::function(const std::string& src, const options_space_t& opt, driver::device *device) {
size_t args_size, init_kernels(src, opt, device);
const grid_t& grid,
driver::stream* stream, driver::device *device) {
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream, device);
} }
void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
runtime::kernel* fn = autotune(args, args_size, grid_fn, stream);
(*fn)(args, args_size, stream, grid_fn(fn->opt));
}
void function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) {
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
}
} }
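For reference, a rough Python analogue of what `function::init_kernels` and `function::autotune` above implement: enumerate the Cartesian product of the option space, compile one kernel per point, and keep the fastest one for a given launch. This is only a sketch; `compile_kernel` and `bench` are hypothetical stand-ins for the C++ `kernel` constructor and `tools::bench`:

import itertools

def init_kernels(num_warps_space, defines_space, compile_kernel):
    # enumerate every point of the option space, as do_loop_nest() does
    kernels, errors = [], []
    keys = list(defines_space.keys())
    for nw, *vals in itertools.product(num_warps_space, *defines_space.values()):
        opt = {'num_warps': nw, 'defines': dict(zip(keys, vals))}
        try:
            kernels.append((opt, compile_kernel(opt)))  # may fail, e.g. out of shared memory
        except RuntimeError as e:
            errors.append((opt, str(e)))
    if not kernels:
        raise RuntimeError('Auto-Tuner could not find any valid configuration:\n' +
                           '\n'.join(f'{o} -> {e}' for o, e in errors))
    return kernels

def autotune(kernels, bench):
    # time every surviving configuration and keep the fastest
    return min(kernels, key=lambda kv: bench(kv[1]))[1]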


@@ -1,39 +0,0 @@
import torch
import triton
class _add(torch.autograd.Function):
src = """
__global__ void add(float* z, float* x, float* y, int N) {
int pid = get_program_id(0);
int offset[TILE] = pid * TILE + 0 ... TILE;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
bool check[TILE] = offset < N;
*pz = *px + *py;
}
"""
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
@staticmethod
def forward(ctx, x, y):
z = torch.empty_like(x).cuda()
N = x.numel()
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
_add.kernel(z,x,y, N, grid=grid)
return z
add = _add.apply
# test
torch.manual_seed(0)
x = torch.rand(900).cuda()
y = torch.rand(900).cuda()
za = x + y
zb = add(x, y)
print(torch.allclose(za,zb))
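Under the API introduced by this commit (see `python/triton/kernel.py` below), the same element-wise addition would look roughly as follows. This is a sketch only: the device is now passed explicitly to `triton.kernel`, tensors are passed as `.data_ptr()`, and compilation options are read as constant attributes (`opt.TILE` instead of `opt.d('TILE')`):

import torch
import triton

src = """
__global__ void add(float* z, float* x, float* y, int N) {
    int pid = get_program_id(0);
    int offset[TILE] = pid * TILE + 0 ... TILE;
    float* pz[TILE] = z + offset;
    float* px[TILE] = x + offset;
    float* py[TILE] = y + offset;
    bool check[TILE] = offset < N;
    *pz = *px + *py;
}
"""

x = torch.rand(900, device='cuda')
y = torch.rand(900, device='cuda')
z = torch.empty_like(x)
N = x.numel()
# device is explicit; torch tensors are no longer converted implicitly
add = triton.kernel(src, device=x.device, defines={'TILE': 1024})
add(z.data_ptr(), x.data_ptr(), y.data_ptr(), N,
    grid=lambda opt: (triton.cdiv(N, opt.TILE), ))
print(torch.allclose(z, x + y))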


@@ -1,70 +0,0 @@
import torch
import triton
class _copy(torch.autograd.Function):
src = """
__global__ void copy(TYPE * X, TYPE * Y,
int M __retune,
int N __retune,
int ldx __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)
*py = *px;
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N;
dtype = x.dtype
y = torch.empty((M,N)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _copy.kernel is None:
_copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
_copy.kernel(x, y, M, N, ldx, grid=grid)
return y
copy = _copy.apply
# test
torch.manual_seed(0)
x = torch.randn(8,4).cuda()
print(x)
ya = x
yb = copy(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)


@@ -1,143 +0,0 @@
import torch
import triton
class _dot(torch.autograd.Function):
src = """
#define STM 4
#define STN 4
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = M / TM;
int gridn = N / TN;
int stgridm = (gridm + STM - 1) / STM;
int stgridn = (gridn + STN - 1) / STN;
int stid = pid / (STM * STN);
int laneid = pid % (STM * STN);
int stm = stid / stgridn;
int stn = stid % stgridn;
int lanem = laneid / STN;
int lanen = laneid % STN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = pidm * TM + 0 ... TM;
int rxn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
"""
@staticmethod
def forward(ctx, a, b):
c = _dot._call(a,b)
return c
kernel = dict()
@staticmethod
def _call(a, b):
# create kernel if necessary
dtype = a.dtype
if dtype not in _dot.kernel:
defines = {
'TYPE' : dtype,
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [128],
'TN' : [128],
'TK' : [32],
'TZ' : [1]
}
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
kernel = _dot.kernel[dtype]
# allocate output
M, K = a.shape
K, N = b.shape
c = torch.empty([M,N], dtype=dtype, device=a.device)
print(kernel.asm('sass', c.device))
print(kernel.asm('ptx', c.device))
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c
dot = _dot.apply
torch.manual_seed(0)
M, N, K = 4096, 4096, 4096
a = torch.rand((M, K)).cuda().half()
b = torch.rand((K, N)).cuda().half()
#a[:] = 1
#b[:] = 1
zc = torch.matmul(a,b)
zc_ = dot(a,b)
print(torch.allclose(zc, zc_))


@@ -1,76 +0,0 @@
import torch
import triton
class _transpose(torch.autograd.Function):
src = """
__global__ void transpose(TYPE * X, TYPE * Y,
int M __retune,
int N __retune,
int ldx __multipleof(8), int ldy __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// create bounds-checking mask
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
// conditional write-back using the conditional dereferencing operator '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N
ldy = M
dtype = x.dtype
y = torch.empty((N,M)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _transpose.kernel is None:
_transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
_transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
return y
transpose = _transpose.apply
# test
torch.manual_seed(0)
x = torch.randn(1024,128).cuda()
print(x)
ya = torch.t(x)
yb = transpose(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)


@@ -95,25 +95,18 @@ class CMakeBuild(build_ext):
        subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
-find_llvm()
-directories = [x[0] for x in os.walk(os.path.join('src', 'include'))]
-data = []
-for d in directories:
-    for htype in ['h', 'hpp']:
-        files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False)
-        data += [os.path.relpath(f, 'src') for f in files]
setup(
    name='triton',
-   version='0.3.0',
+   version='1.0.0',
    author='Philippe Tillet',
-   author_email='ptillet@g.harvard.edu',
+   author_email='phil@openai.com',
    description='A language and compiler for custom Deep Learning operations',
    long_description='',
-   packages=['triton', 'triton/_C'],
+   packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
-   install_requires=['numpy', 'torch', 'sympy'],
+   install_requires=['numpy', 'torch'],
-   package_data={'': data},
+   package_data={'triton/ops': ['*.c'],
+                 'triton/ops/blocksparse': ['*.c']},
    include_package_data=True,
    ext_modules=[CMakeExtension('triton', 'triton/_C/')],
    cmdclass=dict(build_ext=CMakeBuild),
    zip_safe=False,
@@ -122,7 +115,7 @@ setup(
    url='https://github.com/ptillet/triton/',
    download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
    classifiers=[
-       'Development Status :: 4 - Beta',  # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
+       'Development Status :: 3 - Alpha',  # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
        'Intended Audience :: Developers',  # Define that your audience are developers
        'Topic :: Software Development :: Build Tools',
        'License :: OSI Approved :: MIT License',  # Again, pick a license


@@ -13,15 +13,19 @@
#include "triton/ir/function.h"
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
-std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
-CUstream torch_get_cuda_stream(int64_t dev_id);
-CUdevice torch_get_cuda_device(int64_t dev_id);
+std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+std::map<int, std::shared_ptr<rt::function>> id_fn_map;
+std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
+std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
+std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
+extern CUstream torch_get_cuda_stream(int64_t dev_id);
+extern CUdevice torch_get_cuda_device(int64_t dev_id);
/* Grid utilities */
@@ -36,106 +40,123 @@ void delete_grid(const map_key_t& key) {
/* Function utilities */
-void register_fn(const map_key_t& key,
+void register_fn(int op_id,
+                 int dev_id,
                  const std::string& src,
                  const rt::options_space_t& opt) {
-  if(id_fn_map.find(key) == id_fn_map.end())
-    id_fn_map[key].reset(new rt::function(src, opt, ""));
+  if(tt_devices.find(dev_id) == tt_devices.end()) {
+    driver::device* device;
+    driver::stream* stream;
+    if(dev_id >= 0){
+      device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
+      stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
+    }
+    else{
+      device = new triton::driver::host_device();
+      stream = new triton::driver::host_stream();
+    }
+    tt_devices[dev_id].reset(device);
+    tt_streams[dev_id].reset(stream);
+  }
+  if(id_fn_map.find(op_id) == id_fn_map.end()){
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
+  }
+  for(const auto& k: id_fn_map[op_id]->get_kernels()){
+    const rt::options_t* opt = &k.first;
+    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
+    for(auto x: opt->defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    opt_cache_[&k.second->opt] = obj;
+  }
}
-void delete_fn(const map_key_t& key) {
-  id_fn_map.erase(key);
+void delete_fn(int op_id) {
+  id_fn_map.erase(op_id);
}
-std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
-  triton::driver::cu_device device(key.second, false);
-  return id_fn_map[key]->get_asm(mode, &device, opt);
-}
void cleanup() {
  id_grid_map.clear();
  id_fn_map.clear();
+  opt_cache_.clear();
}
size_t make_op_id() {
  return id_fn_map.size();
}
-/* Function signature */
-void make_module(const std::string& src, ir::module* ir,
-                 const runtime::options_space_t& opt) {
-  std::string copy = triton::runtime::function::preheader() + src;
-  // pre-process
-  TokenSequence tokens;
-  Preprocessor cpp(&copy, true);
-  for(auto it: opt.defines){
-    cpp.AddMacro(it.first, &it.second[0]);
-  }
-  cpp.Process(tokens);
-  // parse
-  Parser parser(tokens);
-  parser.Parse();
-  Generator gen(&parser);
-  gen.Gen(ir);
-}
-std::vector<rt::arg_type> get_fn_signature(const std::string& src,
-                                           const runtime::options_space_t& opt) {
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  // extract signature
-  std::vector<rt::arg_type> ret;
-  ir::function_type* ty = fn->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++)
-    ret.push_back(rt::convert(ty->get_param_ty(i)));
-  return ret;
-}
-typedef triton::runtime::options_t options_t;
-typedef triton::runtime::options_space_t options_space_t;
+std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
+  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
+}
+void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
+  // for(size_t n = 0; n < constant_names.size(); n++){
+  //   const torch::Tensor& x = constant_vals[n];
+  //   fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
+}
+pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  auto wrapper = [&grid](const rt::options_t& opt){
+    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
+    for(auto x: opt.defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    return grid(*obj.cast<rt::options_t*>());
+  };
+  rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
+  return opt_cache_.at(&kernel->opt);
+}
+void init_superblocking(pybind11::module &m);
+void init_launch(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";
  // bindings for triton classes
  pybind11::enum_<rt::arg_type>(m, "arg_type")
      .value("int1"  , rt::INT1_T)
      .value("int8"  , rt::INT8_T)
      .value("int16" , rt::INT16_T)
      .value("int32" , rt::INT32_T)
      .value("int64" , rt::INT64_T)
      .value("half"  , rt::HALF_T)
      .value("float" , rt::FLOAT_T)
      .value("double", rt::DOUBLE_T)
      .value("buffer", rt::BUFFER_T);
  pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
      .value("ptx" , rt::ASM_NV_PTX)
      .value("sass", rt::ASM_NV_SASS);
-  pybind11::class_<options_t>(m, "options")
-      .def(pybind11::init<>())
-      .def("d", &options_t::D<int>)
-      .def_readwrite("num_warps", &options_t::num_warps)
-      .def_readwrite("defines" , &options_t::defines);
+  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
+      .def_readwrite("num_warps", &rt::options_t::num_warps)
+      .def_readwrite("defines"  , &rt::options_t::defines);
-  pybind11::class_<options_space_t>(m, "options_space")
-      .def(pybind11::init<>())
-      .def_readwrite("defines", &options_space_t::defines)
-      .def_readwrite("num_warps", &options_space_t::num_warps);
+  pybind11::class_<rt::options_space_t>(m, "options_space")
+      .def(pybind11::init<>())
+      .def_readwrite("num_warps", &rt::options_space_t::num_warps)
+      .def_readwrite("defines"  , &rt::options_space_t::defines);
  // hooks into triton constructs since frameworks may not use pybind11
  m.def("get_fn_signature", &get_fn_signature);
-  m.def("get_fn_asm", &get_fn_asm);
+  // m.def("get_fn_asm", &get_fn_asm);
  m.def("register_grid", &register_grid);
  m.def("delete_grid", &delete_grid);
  m.def("register_fn", &register_fn);
  m.def("delete_fn", &delete_fn);
  m.def("make_op_id", &make_op_id);
  m.def("cleanup", &cleanup);
-  ;
+  m.def("autotune", &autotune, pybind11::return_value_policy::reference);
+  m.def("launch_kernel", &launch_kernel);
+  init_launch(m);
+  init_superblocking(m);
}
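Since `options` is now bound with `pybind11::dynamic_attr()` and `register_fn`/`autotune` copy every integer-valued `#define` onto the options object as an attribute, grid callbacks on the Python side can read compilation constants directly. A small sketch of the resulting idiom (`TM` and `TN` are placeholder define names):

# before: query a define through a method
# grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
# after: integer defines are plain attributes set by the bindings
grid = lambda opt: [triton.cdiv(M, opt.TM), triton.cdiv(N, opt.TN)]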


@@ -1,95 +0,0 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::shared_ptr<drv::device> host_device;
std::shared_ptr<drv::context> host_context;
std::shared_ptr<drv::stream> host_stream;
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
void init_host_stream() {
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(host_context->backend()));
}
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
if(dev_id == -1){
init_host_stream();
host_stream->synchronize();
}
else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
stream.synchronize();
}
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
for(size_t n = 0; n < constant_names.size(); n++){
const torch::Tensor& x = constant_vals[n];
fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
if(dev_id == -1){
init_host_stream();
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
}
else{
C10_CUDA_CHECK(cudaSetDevice(dev_id));
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
}
}
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::cuda_empty_like", &cuda_empty_like)
.op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize);


@@ -0,0 +1,83 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;
int64_t cdiv(int64_t a, int64_t b) {
return (a + b - 1) / b;
}
int64_t largest_pow2_divisor(int64_t a){
if(a % 8 == 0) return 8;
if(a % 4 == 0) return 4;
if(a % 2 == 0) return 2;
return 1;
}
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void cuda_set_device(int64_t dev_id) {
if(dev_id >= 0)
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
void init_launch(pybind11::module &m) {
m.def("cuda_set_device", &cuda_set_device);
m.def("cuda_empty_like", &cuda_empty_like);
m.def("largest_pow2_divisor", &largest_pow2_divisor);
m.def("cdiv", &cdiv);
m.def("cdiv_sum", &cdiv_sum);
m.def("synchronize", &synchronize);
}
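As the comment at the top of this file notes, kernel arguments cross the Python/C++ boundary as a single byte string built with `struct.pack`; `launch_kernel` then hands that buffer to the runtime as an opaque argument blob. A minimal sketch of the packing side, with illustrative format characters (the real ones are derived from the kernel signature via the `codes` table in `triton/kernel.py`):

import struct
import torch

x = torch.empty(1024, device='cuda')
y = torch.empty_like(x)
# e.g. a kernel taking (float* x, float* y, int N): native pointers as 'P', int32 as 'i'
params = struct.pack('PPi', x.data_ptr(), y.data_ptr(), x.numel())
# libtriton.launch_kernel(op_id, dev_id, params, grid_0, grid_1, grid_2)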


@@ -0,0 +1,117 @@
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
torch::Tensor tmp = torch::zeros_like(layout);
auto _tmp = tmp.accessor <int, 3>();
auto _layout = layout.accessor <int, 3>();
auto _idx = idx.accessor <int, 3>();
auto _scratch = scratch.accessor<int, 3>();
std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(size_t h = 0; h < H; h++){
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
for(size_t m = 0; m < M; m++){
for(size_t n = 0; n < N; n++){
int v = _layout[h][m][n];
if(v == 0)
continue;
int n_left= ii_left[max_width-1];
int m_top = ii_top [max_width-1][n];
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
int topleft = (m_top >=0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for(int nn = n_left + 1; nn < n; nn++)
if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
width = 1;
_tmp[h][m][n] = width;
// update n_left ring buffer
for(int k = 0; k < max_width-1; k++)
ii_left[k] = ii_left[k+1];
ii_left[max_width-1] = n;
// update ii_top ring buffer
for(int k = 0; k < max_width-1; k++)
ii_top[k][n] = ii_top[k+1][n];
ii_top[max_width-1][n] = m;
// block is too small -- skip
if(width != max_width)
continue;
// retained blocks are set to zeros
for(size_t km = 0; km < max_width; km++)
for(size_t kn = 0; kn < max_width; kn++)
{
int mm = ii_top[km][n];
int nn = ii_left[kn];
if(mm < 0 || nn < 0)
continue;
_layout[h][mm][nn] = 0;
_tmp[h][mm][nn] = 0;
_scratch[h][current[h]][0] = (int)h;
_scratch[h][current[h]][1] = (int)mm;
_scratch[h][current[h]][2] = (int)nn;
_scratch[h][current[h]][3] = _idx[h][mm][nn];
current[h]++;
}
}
}
}
std::vector<torch::Tensor> to_cat;
for(size_t h = 0; h < H; h++)
if(current[h] > 0)
to_cat.push_back(scratch[h].slice(0, 0, current[h]));
if(!to_cat.empty())
ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
}
ret_t superblock(torch::Tensor layout, int start_width) {
ret_t ret;
// block index
torch::Tensor idx = torch::zeros_like(layout);
int current = 0;
int64_t H = layout.size(0);
int64_t M = layout.size(1);
int64_t N = layout.size(2);
auto _layout = layout.accessor <int, 3>();
auto _idx = idx.accessor<int, 3>();
for(int64_t h = 0; h < H; h++)
for(int64_t m = 0; m < M; m++)
for(int64_t n = 0; n < N; n++){
if(_layout[h][m][n] == 0)
continue;
_idx[h][m][n] = current++;
}
// scratch memory
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
for(int max_width = start_width; max_width > 0; max_width /= 2)
segment_blocks(layout, idx, scratch, max_width, ret);
return ret;
}
void init_superblocking(pybind11::module &m) {
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}
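`superblock` is consumed by `triton.ops.blocksparse.matmul` (see `make_sdd_lut` further down): given an integer {0, 1} layout of shape (heads, M // block, N // block), it returns a list of `(size, nnz)` pairs, where `size` is the super-block width in blocks and each row of `nnz` holds the (head, row, column, block-index) of one constituent block. A hedged usage sketch:

import torch
import triton._C.libtriton as libtriton

layout = torch.randint(2, (2, 8, 8), dtype=torch.int32)  # (H, M//block, N//block)
start_width = 64 // 16                                    # as in make_sdd_lut for block=16
for size, nnz in libtriton.superblock(layout, start_width):
    # size*size blocks were grouped per super-block; nnz has one row per retained block
    print(size, nnz.shape)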


@@ -0,0 +1,50 @@
import itertools
import torch
import triton as tt
import pytest
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
return ret
def mask_tensor(x, mask, block, value = 0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
for at in [False, True]\
for bt in [False, True]\
for block in [16, 32, 64]
]
)
def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
# triton result
op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
rc = op(ra, rb)
# torch result
ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
ta = ta.transpose(2, 3) if TRANS_A else ta
tb = tb.transpose(2, 3) if TRANS_B else tb
tc = torch.matmul(ta, tb)
tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
# compare
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(rc, tc, rtol=rtol, atol=atol)

python/tests/test_conv.py

@@ -0,0 +1,17 @@
import torch
import triton
def test_op():
torch.manual_seed(0)
DTYPE = torch.float16
N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad, stride, = (1, 1), (1, 1)
dilation = (1, 1)
a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
tt_c = triton.ops.conv(a, b, pad, stride)
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)


@@ -0,0 +1,96 @@
import pytest
import itertools
import triton as tt
import torch as th
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
]
for DTYPE in ['float16']
for AT in [False, True]
for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
th.manual_seed(0)
tt.ops._matmul.kernel = dict()
tt.ops._matmul.TM = [TM]
tt.ops._matmul.TN = [TN]
tt.ops._matmul.TK = [TK]
tt.ops._matmul.num_warps = [NWARP]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
b = b.t() if BT else b
th_c = th.matmul(a, b)
tt_c = tt.ops.matmul(a, b)
rtol, atol = {th.float32: (1e-4, 1e-5),
th.float16: (1e-2, 1e-3)}[DTYPE]
assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)
def do_bench(fn, flops = 0, warmup = 10, rep = 50):
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
th.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
th.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms, flops/time_ms*1e-9, ret
def perf_op(dtype=th.float16, warmup=10, rep=50):
AT, BT = False, False
configs = [(N, N, N) for N in [128, 8192]]
for M, N, K in configs:
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
a = a[::,::]
b = b[::,::]
TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
print((M, N, K), TH_MS, TT_MS)


@@ -0,0 +1,8 @@
import torch
import triton
def test_op(M = 1024, N = 1024, dtype = torch.float32):
x = torch.randn(M, N, dtype=dtype, device='cuda')
th_y = torch.softmax(x, dim=-1)
tt_y = triton.ops.softmax(x)
assert torch.allclose(tt_y, th_y)


@@ -1,8 +1,13 @@
-from .kernel import *
+# TODO: torch needs to be imported first
+# or pybind11 shows `munmap_chunk(): invalid pointer`
+import torch
-# clean-up libtriton resources
+# libtriton resources
import atexit
import triton._C.libtriton as libtriton
@atexit.register
def cleanup():
    libtriton.cleanup()
+from .kernel import *
+from . import ops


@@ -15,18 +15,6 @@ codes = {
    libtriton.arg_type.buffer: 'P'
}
-sizes = {
-    libtriton.arg_type.int1: 1,
-    libtriton.arg_type.int8: 1,
-    libtriton.arg_type.int32: 4,
-    libtriton.arg_type.int64: 8,
-    libtriton.arg_type.half: 2,
-    libtriton.arg_type.float: 4,
-    libtriton.arg_type.double: 8,
-    libtriton.arg_type.buffer: 8
-}
def th_to_triton(obj):
    tys = {
        torch.int8: 'char',
@@ -43,92 +31,65 @@ def th_to_triton(obj):
        return [th_to_triton(x)[0] for x in obj]
    return [str(obj)]
def cdiv(a, b):
-    return (a + b - 1) // b
+    return libtriton.cdiv(a, b)
def cdiv_sum(a, b):
    return torch.ops.triton.cdiv_sum(a, b)
def synchronize(device):
    dev_id = device.index
    dev_id = -1 if dev_id is None else dev_id
-    torch.ops.triton.synchronize(dev_id)
+    libtriton.synchronize(dev_id)
+def read(path):
+    with open(path, 'r') as f:
+        source = f.read()
+    return source
class kernel:
-    def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
+    def __init__(self, src, device, defines = dict(), num_warps = [4]):
        self.src = src
        self.opt = libtriton.options_space()
        self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
        self.opt.num_warps = num_warps
+        # device
+        assert device.type in ['cuda', 'cpu']
+        if device.type == 'cuda':
+            self.device = torch.cuda.current_device() if device.index is None else device.index
+        if device.type == 'cpu':
+            self.device = -1
+        # C++ function wrapper
        self.op_id = libtriton.make_op_id()
-        self.registered = set()
-        arg_types = libtriton.get_fn_signature(self.src, self.opt)
-        size = sum([sizes[x] for x in arg_types])
+        libtriton.register_fn(self.op_id, self.device, self.src, self.opt)
+        # debug mode
+        self.is_debug = 'TRITON_DEBUG' in os.environ
+        # signature
+        arg_types = libtriton.get_fn_signature(self.op_id)
        self.tys = ''.join([codes[x] for x in arg_types])
-    def asm(self, mode, device, **kwargs):
-        dev_id = device.index
-        # assembly mode
-        supported = {
-            'ptx': libtriton.asm_mode.ptx,
-            'sass': libtriton.asm_mode.sass,
-        }
-        if mode not in supported:
-            raise('ASM mode must be in ', supported.keys())
-        mode = supported[mode]
-        # disambiguates #defines
-        libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
-        def _single_value_or_err(x, key):
-            if isinstance(x, list) and len(x) == 1:
-                return x[0]
-            if isinstance(x, list) and len(x) > 1:
-                if key in kwargs:
-                    return kwargs[key]
-                raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. '
-                                 'Please supply an explicit value as a keyword argument.')
-            return str(x)
-        defines = dict()
-        for (D, V) in self.opt.defines:
-            defines[D] = _single_value_or_err(V, D)
-        opt = libtriton.options()
-        opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
-        opt.defines = defines
-        # run
-        return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)
-    def __call__(self, *args, **kwargs):
-        if 'TRITON_DEBUG_MODE' in os.environ:
+    def __call__(self, *args, grid):
+        # debug mode (initialize)
+        if self.is_debug:
            _args = args
            args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
            for i in range(len(args)):
                if isinstance(args[i], torch.Tensor):
-                    args[i] = torch.ops.triton.cuda_empty_like(args[i])
+                    args[i] = libtriton.cuda_empty_like(args[i])
                args[i].copy_(_args[i])
-        torch.cuda.synchronize()
-        for x in args:
-            if isinstance(x, torch.Tensor):
-                device = x.device.index
-                device = -1 if device is None else device
-                break
-        # lazily register function for device
-        libtriton.register_fn((self.op_id, device), self.src, self.opt)
-        # launch grid
-        if 'grid' not in kwargs:
-            raise RuntimeError('Must provide grid for kernel launch')
-        grid = kwargs['grid']
-        libtriton.register_grid((self.op_id, device), grid)
-        # re-allocate buffers for auto-tuning
-        if 'autotune_buf' in kwargs:
-            pass
-        # launch
-        params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
-        names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
-        constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
-        torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
-        if 'TRITON_DEBUG_MODE' in os.environ:
-            torch.cuda.synchronize()
+        # initialize cuda device if necessary
+        libtriton.cuda_set_device(self.device)
+        # pack parameters into a byte buffer
+        params = pack(self.tys, *args)
+        # auto-tune if necessary
+        opt = libtriton.autotune(self.op_id, self.device, params, grid)
+        # run kernel
+        grid = grid(opt)
+        grid_0 = grid[0]
+        grid_1 = 1 if len(grid) < 2 else grid[1]
+        grid_2 = 1 if len(grid) < 3 else grid[2]
+        libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
+        # debug mode (finalize)
+        if self.is_debug:
            for i in range(len(args)):
                if isinstance(args[i], torch.Tensor):
                    _args[i].copy_(args[i].clone())


@@ -0,0 +1,4 @@
from .conv import _conv, conv
from .matmul import _matmul, matmul
from .softmax import _softmax, softmax
from . import blocksparse
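These pre-written ops are what the new tests under `python/tests/` exercise; typical usage is just (sketch):

import torch
import triton

a = torch.randn((256, 256), device='cuda', dtype=torch.float16)
b = torch.randn((256, 256), device='cuda', dtype=torch.float16)
c = triton.ops.matmul(a, b)                                 # auto-tuned Triton GEMM
y = triton.ops.softmax(torch.randn((1024, 1024), device='cuda'))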


@@ -0,0 +1 @@
from .matmul import matmul


@@ -0,0 +1,198 @@
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
TYPE* B __readonly __noalias __aligned(16),
TYPE* C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int* lut, int* locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
int offlutn[TN] = blockidn*4;
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for(int k = AS1; k > 0; k -= step) {
acc += a @ b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *?(checka)pa;
b = *?(checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
int rr_offlutn[TN] = rr_blockidn*4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
// write-back directly
if(lockid == 0) {
*?(checkc) pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}


@@ -0,0 +1,467 @@
import triton
import triton._C.libtriton as libtriton
import torch
import os
import math
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
sdd_cache = dict()
dsd_cache = dict()
dds_cache = dict()
locks = dict()
# Given an array sizes representing reduction size for each
# column of a block-mode matrix multiplication,
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
max_size = sizes.max()
min_size = sizes[sizes != 0].min()
#if max_size > min_size * 2.0:
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
#else:
# seg_max = max_size
seg_max = max_size
seg_min = max(triton.cdiv(seg_max, 4), 4)
# split reduction into segments
div = sizes // seg_max
rem = sizes % seg_max
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
width = packs.sum()
segments = torch.empty(width, dtype=sizes.dtype)
column = torch.empty_like(segments)
lockid = torch.zeros_like(segments)
maxid = torch.zeros_like(segments)
nlocks = 0
current = 0
col_idx = 0
for i in range(len(sizes)):
d, r = div[i], rem[i]
isempty = sizes[i] < seg_min
last = current + d + (r >= seg_min) + isempty
# column id
column[current:last] = col_idx
# lock id
if d > 1 or (d == 1 and r >= seg_min):
nlocks += 1
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current+d] = seg_max
if r < seg_min and not isempty:
segments[current+d-1] += r
if r >= seg_min or isempty:
segments[current+d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
return segments, column, lockid, maxid, offsets
@staticmethod
def get_locks(size, dev):
if dev not in _matmul.locks or \
size > _matmul.locks[dev].size(0):
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
return _matmul.locks[dev]
##########################
# SPARSE = DENSE x DENSE #
##########################
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 64 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
width = nnz.shape[0] // (size*size)
h = nnz[:, 0]
i = nnz[:, 1]
j = nnz[:, 2]
b = nnz[:, 3]
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
luts.append(lut.type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, luts, num_locks, widths, packs):
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
is_32_multiple = AS3 % 32 == 0
is_64_multiple = AS3 % 64 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
F32TK = [8, 16]
#F16TK = [16]
#F16TK += [32] if is_32_multiple else []
#F16TK += [64] if is_64_multiple else []
F16TK = [64]
TK = {torch.float32: F32TK,
torch.float16: F16TK}[dtype]
defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
'TK': TK, 'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
kernel = _matmul.sdd_cache[key]
# create output
locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), b.stride(2), block,
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(0),
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
column = _empty.clone()
depth = _empty.clone()
lockid = _empty.clone()
maxid = _empty.clone()
offsets = _empty.clone()
current_offset = 0
current_maxid = 0
for z in range(layout.size(0)):
if trans:
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
# concatenate depth
segments = torch.cat((segments, z_segments))
column = torch.cat((column, z_column))
depth = torch.cat((depth, z_depth))
maxid = torch.cat((maxid, z_maxid))
offsets = torch.cat((offsets, current_offset + z_offsets))
lockid = torch.cat((lockid, z_lockid))
current_offset += layout[z, :, :].sum()
segments *= step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
idx = transform(nnz[:, 2]*block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0 ] -= (div-1)*step
# first increment for each reduction is actually the offset
xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
widx = torch.arange(num_blocks)
else:
widx = _empty.clone()
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum)
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx*block*block
wincs[1:] -= widx[:-1]*block*block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div-1)*step
else:
wincs[:, 1:] = step*block
wincs[:, 0] -= (div - 1)*step*block
wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2*div
segments *= div
# create header
width = column.size(0)
offsets += 6*width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
num_locks = max(1, lockid.max())
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = spdims[0]
BS1 = block * spdims[2 if trans_b else 1]
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': TM, 'TN': block, 'TK': TK,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), block, c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
AS2 = block * spdims[1 if trans_a else 2]
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
TN = [64, 128] if dtype == torch.float32 else [64, 128]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': block, 'TN': TN, 'TK': TK,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
block, b.stride(2), c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
fn = {'sdd': _sdd_matmul.__get__(object),
'dsd': _dsd_matmul.__get__(object),
'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c,
mode, spdims, block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = 8 if dtype == torch.float32 else 16
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a = False, trans_b = False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.trans_a = trans_a
self.trans_b = trans_b
self.mode = mode
self.spdims = layout.shape
self.block = block
self.layout = layout
# pad shapes of a tensor to make it
# compatible with kernel calls
@staticmethod
def _pad_shape(x, is_sparse):
max_dim = 3 if is_sparse else 4
for i in range(max_dim - x.dim()):
x = x.unsqueeze(0)
return x
def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
self.mode, self.spdims, self.block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return c
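
A minimal usage sketch for the block-sparse `matmul` class above (editorial addition, not part of this commit; the shapes, the random layout and the CUDA device are illustrative assumptions):

```python
import torch
# `matmul` is the block-sparse class defined above

Z, H, M, N, K, block = 1, 2, 256, 256, 256, 16
# binary layout selecting which (block x block) tiles of the output are kept ('sdd' mode)
layout = torch.randint(0, 2, (H, M // block, N // block))
a = torch.randn((Z, H, M, K), dtype=torch.float16, device='cuda')
b = torch.randn((Z, H, K, N), dtype=torch.float16, device='cuda')
dot = matmul(layout, block, 'sdd', trans_a=False, trans_b=False)
c = dot(a, b)  # packs the non-zero blocks of the block-sparse product
```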

View File

@@ -1,16 +1,9 @@
-import torch
-import triton
-
-class _conv(torch.autograd.Function):
-    src = """
-__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
+__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
                      TYPE *B __noalias __readonly __aligned(16),
                      TYPE *C __noalias __aligned(16),
                      float alpha,
                      // equivalent matmul
-                     int M __retune,
-                     int N __retune,
-                     int K __retune,
+                     int M, int N, int K,
                      // convolution properties
                      int pad_h, int pad_w, int stride_h, int stride_w,
                      // pointer increment
@@ -131,72 +124,3 @@ class _conv(torch.autograd.Function):
   atomic_xchg(plock, 0);
 #endif
 }
"""
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride, time):
# create kernel if necessary
dtype = a.dtype
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2*pad[0] - R)//stride[0] + 1
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if dtype not in _conv.kernel:
TK = 8
defines = {
'TYPE' : dtype,
'TM' : [16, 32, 64, 128],
'TN' : [16, 32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
triton.cdiv(CO, opt.d('TN'))]
time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
pad[0], pad[1], stride[0], stride[1],
delta,
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
grid=grid, bench=100)
return c
conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
b = torch.rand((CI, R, S, CO)).cuda()
time = [None]
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
c = conv(a, b, pad, stride, time)
print((cc - c).abs().max() / max(cc.max(), c.max()))
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
#zc = torch.matmul(a,b)
#zc_ = dot(a,b)

57
python/triton/ops/conv.py Normal file
View File

@@ -0,0 +1,57 @@
import torch
import triton
import os
class _conv(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride):
# create kernel if necessary
dtype = a.dtype
device = a.device
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2*pad[0] - R)//stride[0] + 1
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if (dtype, device) not in _conv.kernel:
TK = 16
defines = {
'TYPE' : dtype,
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[(dtype, device)] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
delta, kernel = _conv.kernel[(dtype, device)]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
# enqueue
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
pad[0], pad[1], stride[0], stride[1],
delta.data_ptr(),
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
return c
conv = _conv.apply
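
A hypothetical usage sketch for the `conv` op defined above (editorial addition, not part of this commit; shapes and device are illustrative and mirror the deleted tutorial):

```python
import torch

Z, CI, H, W, CO, R, S = 4, 64, 56, 56, 64, 3, 3
a = torch.randn((Z, CI, H, W), dtype=torch.float32, device='cuda')
b = torch.randn((CI, R, S, CO), dtype=torch.float32, device='cuda')  # note the (CI, R, S, CO) weight layout
c = conv(a, b, (1, 1), (1, 1))  # pad, stride
# sanity check against PyTorch (weights permuted to (CO, CI, R, S))
ref = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), None, (1, 1), (1, 1))
print((c - ref).abs().max())
```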

View File

@@ -1,26 +1,24 @@
namespace src {
const char *dot =
R"(
 #define STM 8
 #define STN 8
-__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
+__global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
                     TYPE * B __noalias __readonly __aligned(16),
                     TYPE * C __noalias __aligned(16),
                     float alpha,
-                    int M __retune,
-                    int N __retune,
-                    int K __retune __multipleof(16),
-                    int lda __multipleof(8),
-                    int ldb __multipleof(8),
-                    int ldc __multipleof(8),
+                    int M,
+                    int N,
+                    int K __multipleof(16),
+                    int lda __multipleof(LDA_POW2_DIV),
+                    int ldb __multipleof(LDB_POW2_DIV),
+                    int ldc __multipleof(LDC_POW2_DIV),
                     int* locks) {
   // prologue
   int pid = get_program_id(0);
   int pidz = get_program_id(2);
   int gridm = (M + TM - 1) / TM;
   int gridn = (N + TN - 1) / TN;
+  // swizzle for better L2 performance
   int width = STM*gridn;
   int stm = pid / width;
   int RSTM = min(gridm - stm*STM, STM);
@@ -34,14 +32,15 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   int rm[TM] = pidm * TM + 0 ... TM;
   int rn[TN] = pidn * TN + 0 ... TN;
-  // reduction splitting
+  // split-k for better parallelism
   K = K / TZ;
-  int rk[TK] = pidz * K + 0 ... TK;
+  int rk[TK] = 0 ... TK;
   // pointers to operands
-  int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
-  int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
+  int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
+  int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
   TYPE* pa[TM, TK] = A + offa;
   TYPE* pb[TK, TN] = B + offb;
   // prefetches operands
   bool checka[TM, TK] = rk[newaxis, :] < K;
   bool checkb[TK, TN] = rk[:, newaxis] < K;
@@ -49,14 +48,25 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   TYPE b[TK, TN] = checkb ? *pb : 0;
   pa += TK * STRIDE_AK;
   pb += TK * STRIDE_BK;
   // reduction loop
   float acc[TM, TN] = 0;
   for(int k = K; k > 0; k -= TK){
-    bool checka[TM, TK] = k > TK;
-    bool checkb[TK, TN] = k > TK;
+#if (IS_TK_DIV_K==1)
+    bool checkk[TK] = k > TK;
+#else
+    bool checkk[TK] = rk < k - TK;
+#endif
+    bool checka[TM, TK] = checkk[newaxis, :];
+    bool checkb[TK, TN] = checkk[:, newaxis];
     acc += a @ b;
+#if (IS_TK_DIV_K==1)
     a = *?(checka)pa;
     b = *?(checkb)pb;
+#else
+    a = checka ? *pa : 0;
+    b = checkb ? *pb : 0;
+#endif
     pa += TK * STRIDE_AK;
     pb += TK * STRIDE_BK;
   }
@@ -68,8 +78,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   int rcn[TN] = pidn * TN + 0 ... TN;
   int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
   TYPE* pc[TM, TN] = C + offc;
-  bool checkc[TM, TN] = rcm[:, newaxis] < M &&
-                        rcn[newaxis, :] < N;
+  bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
 #if (TZ==1)
   *?(checkc) pc = c;
 #else
@@ -86,6 +95,3 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   atomic_xchg(plock, 0);
 #endif
 }
)";
}

View File

@@ -0,0 +1,80 @@
import torch
import triton
import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
TM = [128]
TN = [128]
TK = [32]
TZ = 1
num_warps = [4]
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
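# Note (editorial): the power-of-two divisors computed above are passed to the
# kernel as LDA/LDB/LDC_POW2_DIV so that the `__multipleof` hints on the leading
# dimensions let the compiler assume alignment and vectorize memory accesses.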
_locks = dict()
_kernels = dict()
@staticmethod
def _call(a, b):
dtype = a.dtype
device = a.device
# allocate output
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
lda = a.stride(0) if is_a_row else a.stride(1)
ldb = b.stride(0) if is_b_row else b.stride(1)
ldc = c.stride(0)
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 32 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
if key not in _matmul._kernels:
defines = {
'TYPE' : dtype,
'STRIDE_AM' : 'lda' if is_a_row else '1',
'STRIDE_AK' : '1' if is_a_row else 'lda',
'STRIDE_BK' : 'ldb' if is_b_row else '1',
'STRIDE_BN' : '1' if is_b_row else 'ldb',
'LDA_POW2_DIV': lda_pow2_div,
'LDB_POW2_DIV': ldb_pow2_div,
'LDC_POW2_DIV': ldc_pow2_div,
'TM' : _matmul.TM,
'TN' : _matmul.TN,
'TK' : _matmul.TK,
'TZ' : _matmul.TZ,
'IS_TK_DIV_K' : is_tk_div_k
}
_matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
kernel = _matmul._kernels[key]
# locks for split-k
if device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
kernel(*args, grid=grid)
return c
@staticmethod
def forward(ctx, a, b):
c = _matmul._call(a,b)
return c
matmul = _matmul.apply
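
A hypothetical usage sketch for the dense `matmul` op above (editorial addition, not part of this commit; the op expects 2D inputs on a CUDA device):

```python
import torch

a = torch.randn((512, 768), dtype=torch.float16, device='cuda')
b = torch.randn((768, 1024), dtype=torch.float16, device='cuda')
c = matmul(a, b)
print((c - torch.matmul(a, b)).abs().max())
```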

View File

@@ -0,0 +1,8 @@
__global__ void forward(TYPE* X, TYPE* Y) {
int pid = get_program_id(0);
int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
float x[BLOCK] = *(X + off);
float shifted[BLOCK] = exp(x - x[max]);
float sum = shifted[+];
*(Y + off) = shifted / sum;
}
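
For reference, the kernel above computes a numerically stable softmax of each length-`BLOCK` row: subtracting the row maximum before exponentiating leaves the result unchanged but avoids overflow,

$$ y_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}} $$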

View File

@@ -0,0 +1,27 @@
import torch
import triton
import os
kernels = dict()
def get_kernel(block, dtype, device):
key = (block, dtype, device)
if key not in kernels:
src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
defines = {'BLOCK': block, 'TYPE': dtype}
kernels[key] = triton.kernel(src, device = device, defines = defines)
return kernels[key]
class _softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = torch.empty_like(x)
M, N = x.shape
kernel = get_kernel(N, x.dtype, x.device)
kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
return y
softmax = _softmax.apply
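
A hypothetical usage sketch for the `softmax` op above (editorial addition, not part of this commit; note that the kernel is specialized with `BLOCK` equal to the number of columns, so each row is handled by a single program):

```python
import torch

x = torch.randn((64, 512), dtype=torch.float32, device='cuda')
y = softmax(x)
print((y - torch.softmax(x, dim=1)).abs().max())
```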

View File

@@ -0,0 +1,335 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "induced-zoning",
"metadata": {},
"source": [
"# Getting Started"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
"* The basic syntax of the Triton programming language\n",
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
"* The best practices for validating and benchmarking custom ops against native reference implementations"
]
},
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"# Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "collectible-belle",
"metadata": {},
"source": [
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
"\n",
"\n",
"```c\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // The `get_program_id(i)` returns the i-th coordinate\n",
" // of the program in the overaching SPMD context\n",
" // (a.k.a launch grid). This is what allows us to process\n",
" // different chunks of data in parallel.\n",
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
" // is similar to blockIdx.{x,y,z}\n",
" int pid = get_program_id(0);\n",
" // In Triton, arrays are first-class citizen. In other words,\n",
" // they are primitives data-types and are -- contrary to C and\n",
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
" // memory.\n",
" // In the few lines below, we create an array of `BLOCK` pointers\n",
" // whose memory values are, e.g.:\n",
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz [BLOCK] = z + offset;\n",
" float* px [BLOCK] = x + offset;\n",
" float* py [BLOCK] = y + offset;\n",
" // Simple element-wise control-flow for load/store operations can\n",
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
" // or the conditional dereferencing operator `*?(cond)ptr\n",
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
" // write-back `z`\n",
" bool check[BLOCK] = offset < N;\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"```\n",
"\n",
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
]
},
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"# Writing the Torch bindings"
]
},
{
"cell_type": "markdown",
"id": "numerical-agency",
"metadata": {},
"source": [
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
"\n",
"To create a `triton.kernel`, you only need three things:\n",
"* `source: string`: the source-code of the kernel you want to create\n",
"* `device: torch.device`: the device you want to compile this code for\n",
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
"\n",
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "sporting-keyboard",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"# here we just copy-paste the above code without the extensive comments.\n",
"# you may prefer to store it in a .c file and load it from there instead.\n",
"_src = \"\"\"\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // program id\n",
" int pid = get_program_id(0);\n",
" // create arrays of pointers\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz[BLOCK] = z + offset;\n",
" float* px[BLOCK] = x + offset;\n",
" float* py[BLOCK] = y + offset;\n",
" // bounds checking\n",
" bool check[BLOCK] = offset < N;\n",
" // write-back\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
" \"\"\"\n",
"# This function returns a callable `triton.kernel` object\n",
"# created from the above source code.\n",
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
"# We compile the kernel with -DBLOCK=1024\n",
"_kernels = dict()\n",
"def make_add_kernel(device):\n",
" if device not in _kernels:\n",
" defines = {'BLOCK': 1024}\n",
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
" return _kernels[device]\n",
"\n",
"# This is a standard torch custom autograd Function\n",
"# The only difference is that we can now use the above kernel\n",
"# in the `forward` and `backward` functions.`\n",
"class _add(torch.autograd.Function):\n",
" \n",
" @staticmethod\n",
" def forward(ctx, x, y):\n",
" # constraints of the op\n",
" assert x.dtype == torch.float32\n",
" # *allocate output*\n",
" z = torch.empty_like(x)\n",
" # *create launch grid*:\n",
" # this is a function which takes compilation parameters `opt`\n",
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
" # triton.cdiv is a shortcut for ceil division:\n",
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
" N = z.shape[0]\n",
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
" # *launch kernel*:\n",
" # pointer to the data of torch tensors can be retrieved with\n",
" # the `.data_ptr()` method\n",
" kernel = make_add_kernel(z.device)\n",
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
" return z\n",
"# Just like we standard PyTorch ops\n",
"# We use the `.apply` method to create a \n",
"# callable object for our function\n",
"add = _add.apply"
]
},
{
"cell_type": "markdown",
"id": "separated-polyester",
"metadata": {},
"source": [
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
]
},
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"# Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "supported-ribbon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"The maximum difference between torch and triton is 0.0\n"
]
}
],
"source": [
"torch.manual_seed(0)\n",
"x = torch.rand(98432, device='cuda')\n",
"y = torch.rand(98432, device='cuda')\n",
"za = x + y\n",
"zb = add(x, y)\n",
"print(za)\n",
"print(zb)\n",
"print(f'The maximum difference between torch and triton is '\n",
" f'{torch.max(torch.abs(za - zb))}')"
]
},
{
"cell_type": "markdown",
"id": "otherwise-canadian",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"# Writing a Benchmark"
]
},
{
"cell_type": "markdown",
"id": "historic-glass",
"metadata": {},
"source": [
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-luxembourg",
"metadata": {},
"outputs": [],
"source": [
"# We now want to benchmark the performance of `add`\n",
"# Against that of PyTorch for increasing vector sizes\n",
"def do_bench(fn, warmup = 10, rep = 50):\n",
" start_event = torch.cuda.Event(enable_timing=True)\n",
" end_event = torch.cuda.Event(enable_timing=True)\n",
" ret = fn()\n",
" for i in range(warmup):\n",
" fn()\n",
" torch.cuda.synchronize()\n",
" start_event.record()\n",
" for i in range(rep):\n",
" fn()\n",
" end_event.record()\n",
" torch.cuda.synchronize()\n",
" time_ms = start_event.elapsed_time(end_event) / rep\n",
" return time_ms"
]
},
{
"cell_type": "markdown",
"id": "hairy-claim",
"metadata": {},
"source": [
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "pleasant-valley",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131072 0.020 0.003\n",
"262144 0.019 0.004\n",
"524288 0.016 0.016\n",
"1048576 0.033 0.033\n",
"2097152 0.071 0.070\n",
"4194304 0.142 0.144\n",
"8388608 0.287 0.286\n",
"16777216 0.572 0.568\n",
"33554432 1.139 1.110\n"
]
}
],
"source": [
"for N in [2**i for i in range(17, 26, 1)]:\n",
" x = torch.rand(N, device='cuda')\n",
" y = torch.rand(N, device='cuda')\n",
" triton_ms = do_bench(lambda: add(x, y))\n",
" torch_ms = do_bench(lambda: x + y)\n",
" # print the performance of triton and torch as well as the achieved bandwidth\n",
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
]
},
{
"cell_type": "markdown",
"id": "juvenile-supplement",
"metadata": {},
"source": [
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "agreed-backing",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,3 +0,0 @@
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/common")
add_subdirectory(bench)
add_subdirectory(unit)

View File

@@ -1,39 +0,0 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "conv.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<int, int, int, int, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs = {
// {1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 128, 128, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 384, 384, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 512, 512, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 768, 768, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 1024, 1024, 3, 3, 1, 1, 1, 1},
// {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1},
{1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1},
// {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1}
};
int Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w;
for(const auto& c: configs){
std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c;
std::cout << "// " << c ;
for(auto perf: bench_conv(context, stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

View File

@@ -1,36 +0,0 @@
#include <iostream>
#include <tuple>
#include "copy.h"
#include "triton/driver/backend.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>> config_t;
std::vector<config_t> configs = {
{{4096*4096}, {0}, {0}},
{{4096, 4096}, {0, 1}, {0, 1}},
{{4096, 4096}, {0, 1}, {1, 0}},
{{4096, 4096}, {1, 0}, {0, 1}},
{{4096, 4096}, {1, 0}, {1, 0}},
{{256, 256, 256}, {0, 1, 2}, {0, 1, 2}},
{{256, 256, 256}, {0, 1, 2}, {0, 2, 1}},
{{256, 256, 256}, {1, 0, 2}, {1, 2, 0}},
{{256, 256, 256}, {1, 2, 0}, {1, 0, 2}}
// {{256, 256, 256}, {2, 0, 1}, {0, 1, 2}},
// {{256, 256, 256}, {2, 1, 0}, {0, 2, 1}}
};
// does the work
std::vector<int32_t> shape;
std::vector<int32_t> ord_x, ord_y;
for(const auto& c: configs){
std::tie(shape, ord_x, ord_y) = c;
std::cout << "// " << c << std::flush;
for(auto perf: bench_copy_nd(context, stream, HALF, shape, ord_x, ord_y))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

View File

@@ -1,54 +0,0 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/copy.h"
#include "util.h"
#include "cuda/cublas.h"
std::vector<double> do_bench(drv::stream* stream, int32_t N){
typedef float NumericT;
std::string ty = "float";
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
// create inputs
auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, N*dt_nbytes));
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, N*dt_nbytes));
// create options
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"TN", {"128"}});
opt.num_warps = {1, 2, 4, 8};
// create function
rt::function function(src::copy1d, opt);
// benchmark available libraries
std::vector<double> result;
auto gbps = [&](double ns) { return 2*N*dt_nbytes / (ns * 1e-9) * 1e-9; };
// triton
double triton_ns = triton::tools::bench([&]() { function({&*dx, &*dy, N}, grid1d(N), stream);}, stream);
result.push_back(gbps(triton_ns));
// done
return result;
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<int> config_t;
std::vector<config_t> configs = { 1024*1024*32 };
int N;
for(const auto& c: configs){
std::tie(N) = c;
std::cout << "// " << c << std::flush;
for(auto perf: do_bench(stream, N))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

View File

@@ -1,72 +0,0 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "dot.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
std::vector<config_t> tmp = {
// config_t{ord, x[0], x[1], 128, 128, 128},
// config_t{ord, x[0], x[1], 256, 256, 256},
// config_t{ord, x[0], x[1], 384, 384, 384},
// config_t{ord, x[0], x[1], 512, 512, 512},
// config_t{ord, x[0], x[1], 768, 768, 768},
// config_t{ord, x[0], x[1], 1024, 1024, 1024},
// config_t{ord, x[0], x[1], 1280, 1280, 1280},
// config_t{ord, x[0], x[1], 1536, 1536, 1536},
// config_t{ord, x[0], x[1], 2048, 2048, 2048},
config_t{ord, x[0], x[1], 8192, 8192, 8192},
// config_t{ord, x[0], x[1], 256, 16, 256},
// config_t{ord, x[0], x[1], 512, 16, 512},
// config_t{ord, x[0], x[1], 768, 16, 768},
// config_t{ord, x[0], x[1], 1024, 16, 1024},
// config_t{ord, x[0], x[1], 1280, 16, 1280},
// config_t{ord, x[0], x[1], 1536, 16, 1536},
// config_t{ord, x[0], x[1], 2048, 16, 2048},
// config_t{ord, x[0], x[1], 3072, 16, 3072},
// config_t{ord, x[0], x[1], 4096, 16, 4096},
// config_t{ord, x[0], x[1], 5120, 16, 5120},
// config_t{ord, x[0], x[1], 6144, 16, 6144},
// config_t{ord, x[0], x[1], 7168, 16, 7168},
// config_t{ord, x[0], x[1], 64, 64, 4096},
// config_t{ord, x[0], x[1], 64, 64, 8192},
// config_t{ord, x[0], x[1], 64, 64, 16384},
// config_t{ord, x[0], x[1], 64, 64, 32768},
// config_t{ord, x[0], x[1], 64, 64, 65536},
// config_t{ord, x[0], x[1], 64, 64, 131072}
// config_t{ord, x[0], x[1], 127008, 768, 576},
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
// config_t{ord, x[0], x[1], 16, 2048, 2048},
// config_t{ord, x[0], x[1], 32, 2048, 2048},
// config_t{ord, x[0], x[1], 64, 2048, 2048},
// config_t{ord, x[0], x[1], 128, 2048, 2048},
// config_t{ord, x[0], x[1], 7000, 2048, 2048},
// config_t{ord, x[0], x[1], 16, 4096, 4096},
// config_t{ord, x[0], x[1], 32, 4096, 4096},
// config_t{ord, x[0], x[1], 64, 4096, 4096},
// config_t{ord, x[0], x[1], 128, 4096, 4096},
// config_t{ord, x[0], x[1], 7000, 4096, 4096}
};
configs.insert(configs.end(), tmp.begin(), tmp.end());
}
// does the work
std::vector<int> ord;
bool AT, BT;
int32_t M, N, K;
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c ;
for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

View File

@@ -1,140 +0,0 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/conv.h"
#include "cuda/cublas.h"
#include "util.h"
enum run_mode_t {
BENCH,
TEST
};
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<class T>
void triton_conv(drv::context* context, drv::stream* stream,
int Z, int CI, int H, int W, int CO, int R, int S,
int pad_h, int pad_w, int stride_h, int stride_w,
run_mode_t mode, std::vector<double>& bench, bool &test){
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::device* device = context->device();
int P = (H + 2*pad_h - R)/stride_h + 1;
int Q = (W + 2*pad_w - S)/stride_w + 1;
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CO*P*Q*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CI*H*W*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*CO*dt_nbytes));
auto ddelta = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*4));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
std::vector<int32_t> hdelta(CI*R*S);
int TK = 16;
for(int i = 0; i < hdelta.size(); i++){
int s = i % S;
int cr = i / S;
int r = cr % R;
int c = cr / R;
int nexti = i + TK;
int nexts = nexti % S;
int nextcr = nexti / S;
int nextr = nextcr % R;
int nextc = nextcr / R;
hdelta[i] = (nextc - c)*W*H + (nextr - r)*W + (nexts - s);
}
stream->write(&*ddelta, true, 0, hdelta);
// macros
rt::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {std::to_string(TK)}});
opt.defines.push_back({"TZ", {"1"}});
opt.defines.push_back({"RR", {std::to_string(R)}});
opt.defines.push_back({"SS", {std::to_string(S)}});
opt.defines.push_back({"PP", {std::to_string(P)}});
opt.defines.push_back({"QQ", {std::to_string(Q)}});
opt.defines.push_back({"HH", {std::to_string(H)}});
opt.defines.push_back({"WW", {std::to_string(W)}});
opt.num_warps = {4};
// arguments
std::stringstream oss;
rt::add_arg(oss, *da->cu());
rt::add_arg(oss, *db->cu());
rt::add_arg(oss, *dc->cu());
rt::add_arg(oss, (float)1);
rt::add_arg(oss, Z*P*Q);
rt::add_arg(oss, CO);
rt::add_arg(oss, CI*R*S);
rt::add_arg(oss, pad_h);
rt::add_arg(oss, pad_w);
rt::add_arg(oss, stride_h);
rt::add_arg(oss, stride_w);
rt::add_arg(oss, *ddelta->cu());
rt::add_arg(oss, W*H*CI);
rt::add_arg(oss, W*H);
rt::add_arg(oss, W);
rt::add_arg(oss, 1);
rt::add_arg(oss, CO*S*R);
rt::add_arg(oss, CO*S);
rt::add_arg(oss, CO);
rt::add_arg(oss, 1);
rt::add_arg(oss, Q*P*CO);
rt::add_arg(oss, Q*P);
rt::add_arg(oss, Q);
rt::add_arg(oss, 1);
// kernels
rt::function function(src::conv, opt);
auto grid = [Z,P,Q,CO](const rt::options_t& x) {
return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
ceil(CO , x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
bench.push_back(tflops(triton_ns));
}
std::vector<double> bench_conv(drv::context* context, drv::stream* stream, dtype_t dtype,
int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S,
int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) {
std::vector<double> bench;
bool test;
switch(dtype){
case HALF: triton_conv<half_float::half>(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
case FLOAT: triton_conv<float>(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
case DOUBLE: triton_conv<double>(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
default: break;
}
return bench;
}

View File

@@ -1,186 +0,0 @@
#include "src/copy.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "util.h"
int32_t off(const std::vector<int32_t>& idx, const std::vector<int32_t>& strides) {
int32_t res = 0;
for(size_t d = 0; d < idx.size(); d++)
res += idx[d] * strides[d];
return res;
}
enum run_mode_t {
BENCH,
TEST
};
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
void cc_copy_nd(const std::vector<T>& x, std::vector<T>& y,
const std::vector<int32_t>& shape,
const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
size_t rank = shape.size();
// strides for x
std::vector<int32_t> x_strides(shape.size());
for(size_t d = 0; d < rank; d++)
x_strides[x_order[d]] = (d == 0) ? 1 : (x_strides[x_order[d-1]] * shape[x_order[d-1]]);
// strides for y
std::vector<int32_t> y_strides(shape.size());
for(size_t d = 0; d < rank; d++)
y_strides[y_order[d]] = (d == 0) ? 1 : (y_strides[y_order[d-1]] * shape[y_order[d-1]]);
// copy 1d
if(rank == 1)
for(int32_t i = 0; i < shape[0]; i++)
y[off({i}, y_strides)] = x[off({i}, x_strides)];
// copy 2d
if(rank == 2)
for(int32_t i = 0; i < shape[0]; i++)
for(int32_t j = 0; j < shape[1]; j++)
y[off({i, j}, y_strides)] = x[off({i, j}, x_strides)];
// copy 3d
if(rank == 3)
for(int32_t i = 0; i < shape[0]; i++)
for(int32_t j = 0; j < shape[1]; j++)
for(int32_t k = 0; k < shape[2]; k++)
y[off({i, j, k}, y_strides)] = x[off({i, j, k}, x_strides)];
}
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<typename T>
void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vector<int32_t>& shape,
const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order,
std::vector<std::vector<std::string>> TS,
run_mode_t mode, std::vector<double>& bench, bool &test) {
std::string ty = to_string<T>::value;
size_t dtsize = sizeof(T);
drv::device* device = context->device();
// rank
size_t rank = shape.size();
// size
size_t size = 1;
for(int32_t d: shape)
size *= d;
std::vector<std::string> shapename = {"S0", "S1", "S2"};
// strides for x
std::vector<std::string> x_strides = {"1"};
for(size_t d = 0; d < rank - 1; d++)
x_strides.push_back(x_strides[d] + " * " + shapename[x_order[d]]);
// strides for y
std::vector<std::string> y_strides = {"1"};
for(size_t d = 0; d < rank - 1; d++)
y_strides.push_back(y_strides[d] + " * " + shapename[y_order[d]]);
// create inputs
auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
// create options
rt::options_space_t opt;
// macros
opt.defines.push_back({"TYPE", {ty}});
for(size_t d = 0; d < rank; d++)
opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
for(size_t d = 0; d < rank; d++)
opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
if(TS.empty())
TS = tile_nd(rank);
for(size_t d = 0; d < rank; d++)
opt.defines.push_back({"TS" + std::to_string(d), TS[d]});
opt.num_warps = {4};
// kernel
rt::function function(src::copy_nd[rank - 1], opt);
std::stringstream oss;
rt::add_arg(oss, *dx->cu());
rt::add_arg(oss, *dy->cu());
rt::add_arg(oss, (uint32_t)shape[0]);
if(shape.size() > 1) rt::add_arg(oss, (uint32_t)shape[1]);
if(shape.size() > 2) rt::add_arg(oss, (uint32_t)shape[2]);
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
auto grid = grid_nd(shape, ts);
// metrics
if(mode == BENCH){
auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
bench.push_back(gbps(triton_ns));
}
// test triton
if(mode == TEST){
std::vector<T> hx(size);
std::vector<T> hy(size);
std::vector<T> ry(size);
for(size_t i = 0; i < hx.size(); i++)
hx[i] = static_cast<T>((float)rand()/RAND_MAX);
stream->write(&*dx, true, 0, hx);
function((void**)oss.str().data(), oss.str().size(), grid, stream, device);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
cc_copy_nd(hx, ry, shape, x_order, y_order);
test = testing::diff(hy, ry);
}
}
std::vector<double> bench_copy_nd(drv::context* context, drv::stream* stream, dtype_t dtype, const std::vector<int32_t>& shape,
const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
std::vector<double> bench;
bool test;
switch(dtype){
case HALF:
triton_copy_nd<half_float::half>(context, stream, shape, x_order, y_order, {}, BENCH, bench, test);
break;
case FLOAT:
triton_copy_nd<float>(context, stream, shape, x_order, y_order, {}, BENCH, bench, test);
break;
default: break;
}
return bench;
}
bool test_copy_nd(drv::context* context, drv::stream* stream, dtype_t dtype, const std::vector<int32_t>& shape,
const std::vector<int32_t>& TS,
const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
std::vector<double> bench;
bool test;
std::vector<std::vector<std::string>> TSS;
for(int32_t d: TS)
TSS.push_back({std::to_string(d)});
switch(dtype){
case HALF:
triton_copy_nd<half_float::half>(context, stream, shape, x_order, y_order, TSS, TEST, bench, test);
break;
case FLOAT:
triton_copy_nd<float>(context, stream, shape, x_order, y_order, TSS, TEST, bench, test);
break;
default: break;
}
return test;
}

View File

@@ -1,218 +0,0 @@
/* Copyright 2019 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <algorithm>
#include <vector>
#include <cassert>
#include "forward.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/driver/context.h"
#include "triton/driver/error.h"
#include "triton/tools/bench.hpp"
class cublas {
private:
template <class F>
struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
initializer();
if(cache == nullptr){
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
triton::driver::check(res);
return res;
}
public:
static bool cublasinit();
static cublasStatus_t cublasSetMathMode(cublasHandle_t h, cublasMath_t m);
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId);
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const void *alpha, const void *A, cudaDataType Atype, int lda,
const void *B, cudaDataType Btype, int ldb, const void *beta,
void *C, cudaDataType Ctype, int ldc,
cudaDataType computeType, cublasGemmAlgo_t algo);
private:
static void* so_;
static void* cublasGetStream_v2_;
static void* cublasSetStream_v2_;
static void* cublasCreate_v2_;
static void* cublasGemmEx_;
static void* cublasSetMathMode_;
};
void* cublas::so_;
void* cublas::cublasGetStream_v2_;
void* cublas::cublasSetStream_v2_;
void* cublas::cublasCreate_v2_;
void* cublas::cublasGemmEx_;
void* cublas::cublasSetMathMode_;
bool cublas::cublasinit() {
if(so_==nullptr)
so_ = dlopen("libcublas.so", RTLD_LAZY);
return so_ != nullptr;
}
cublasStatus_t cublas::cublasGetStream_v2(cublasHandle_t h, cudaStream_t *a)
{ return f_impl<cublas::cublasinit>(so_, cublasGetStream_v2, cublasGetStream_v2_, "cublasGetStream_v2", h, a); }
cublasStatus_t cublas::cublasSetStream_v2(cublasHandle_t h, cudaStream_t a)
{ return f_impl<cublas::cublasinit>(so_, cublasSetStream_v2, cublasSetStream_v2_, "cublasSetStream_v2", h, a); }
cublasStatus_t cublas::cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,
const void *alpha, const void *A, cudaDataType Atype, int lda,
const void *B, cudaDataType Btype, int ldb, const void *beta,
void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo) {
return f_impl<cublas::cublasinit>(so_, cublasGemmEx, cublasGemmEx_, "cublasGemmEx", handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo);
}
cublasStatus_t cublas::cublasCreate_v2(cublasHandle_t *h) {
return f_impl<cublas::cublasinit>(so_, cublasCreate_v2, cublasCreate_v2_, "cublasCreate_v2", h);
}
cublasStatus_t cublas::cublasSetMathMode(cublasHandle_t h, cublasMath_t m) {
return f_impl<cublas::cublasinit>(so_, cublasSetMathMode, cublasSetMathMode_, "cublasSetMathMode", h, m);
}
inline cublasGemmAlgo_t cublasGemmFastest(
triton::driver::stream* stream,
cublasHandle_t handle, cudaDataType cudt,
cublasOperation_t AT, cublasOperation_t BT,
int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc) {
// initialize list of cublas algorithms
static std::vector<cublasGemmAlgo_t> algorithms;
if(algorithms.empty()) {
// non-tensor ops
for(int i = -1; i < 24; i++)
algorithms.push_back((cublasGemmAlgo_t)i);
// tensor ops
for(int i = 99; i < 116; i++)
algorithms.push_back((cublasGemmAlgo_t)i);
}
// cache to avoid re-benchmarking
typedef std::tuple<cudaDataType_t,
cublasOperation_t, cublasOperation_t,
int32_t, int32_t, int32_t> key_t;
static std::map<key_t, cublasGemmAlgo_t> cache;
key_t key(cudt, AT, BT, M, N, K);
// benchmark algorithms if necessary
if(cache.find(key) == cache.end()){
std::vector<double> times;
for(cublasGemmAlgo_t a: algorithms) {
cublasStatus_t status;
double nanosec = triton::tools::bench([&](){ status = cublas::cublasGemmEx(handle, AT, BT,
M, N, K,
alpha, (const void*)A, cudt, lda,
(const void*)B, cudt, ldb,
beta, (void*)C, cudt, ldc, CUDA_R_32F,
a); }, stream);
if(status != CUBLAS_STATUS_SUCCESS)
nanosec = INFINITY;
}
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
assert(times[argmin] != INFINITY);
cache.insert({key, algorithms[argmin]});
}
// return best algorithm
return cache.at(key);
}
/* Get cuBLAS handle */
inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
static std::map<CUstream, cublasHandle_t> cache;
CUstream key = *stream->cu();
// create handle if necessary
if(cache.find(key) == cache.end()) {
cublasHandle_t handle;
if(cublas::cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
throw std::runtime_error("Error: could not create cuBLAS handle");
cublas::cublasSetStream_v2(handle, key);
cache.insert({key, handle});
}
// return handle for the stream
return cache.at(key);
}
/* Simplified API for default GEMM */
inline void cublasGemm(cublasDataType_t dtype,
triton::driver::stream* stream,
bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
void* alpha, triton::driver::buffer* A, int32_t lda,
triton::driver::buffer* B, int32_t ldb,
void* beta, triton::driver::buffer* C, int32_t ldc,
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
// get handle
static cublasHandle_t handle = cublasGetHandle(stream);
// set math mode
if(dtype == CUDA_R_16F)
cublas::cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
// cuda types
static const std::map<char, cublasOperation_t> cu_op = {
{false, CUBLAS_OP_N},
{true, CUBLAS_OP_T}
};
cublasOperation_t opa = cu_op.at(AT);
cublasOperation_t opb = cu_op.at(BT);
// benchmark fastest
if(fastest)
*fastest = cublasGemmFastest(stream, handle, dtype, opa, opb, M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
else {
// execute supplied algo
cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
alpha, (const void*)*A->cu(), dtype, lda,
(const void*)*B->cu(), dtype, ldb,
beta, (void*)*C->cu(), dtype, ldc, CUDA_R_32F, algo);
}
}

View File

@@ -1,105 +0,0 @@
#ifndef _COMMON_CUDA_FORWARDS_H_
#define _COMMON_CUDA_FORwARDS_H_
struct cublasContext;
typedef struct cublasContext *cublasHandle_t;
struct CUstream_st;
typedef struct CUstream_st *cudaStream_t;
/* CUBLAS status type returns */
typedef enum{
CUBLAS_STATUS_SUCCESS =0,
CUBLAS_STATUS_NOT_INITIALIZED =1,
CUBLAS_STATUS_ALLOC_FAILED =3,
CUBLAS_STATUS_INVALID_VALUE =7,
CUBLAS_STATUS_ARCH_MISMATCH =8,
CUBLAS_STATUS_MAPPING_ERROR =11,
CUBLAS_STATUS_EXECUTION_FAILED=13,
CUBLAS_STATUS_INTERNAL_ERROR =14,
CUBLAS_STATUS_NOT_SUPPORTED =15,
CUBLAS_STATUS_LICENSE_ERROR =16
} cublasStatus_t;
/*For different GEMM algorithm */
typedef enum {
CUBLAS_GEMM_DFALT = -1,
CUBLAS_GEMM_DEFAULT = -1,
CUBLAS_GEMM_ALGO0 = 0, // maxwell_sgemm_32x128_nt
CUBLAS_GEMM_ALGO1 = 1, // maxwell_sgemm_64x64_nt
CUBLAS_GEMM_ALGO2 = 2, // maxwell_sgemm_128x32_nt
CUBLAS_GEMM_ALGO3 = 3, // maxwell_sgemm_128x64_nt
CUBLAS_GEMM_ALGO4 = 4, // maxwell_sgemm_128x128_nt
CUBLAS_GEMM_ALGO5 = 5,
CUBLAS_GEMM_ALGO6 = 6,
CUBLAS_GEMM_ALGO7 = 7,
CUBLAS_GEMM_ALGO8 = 8,
CUBLAS_GEMM_ALGO9 = 9,
CUBLAS_GEMM_ALGO10 = 10,
CUBLAS_GEMM_ALGO11 = 11,
CUBLAS_GEMM_ALGO12 = 12,
CUBLAS_GEMM_ALGO13 = 13,
CUBLAS_GEMM_ALGO14 = 14,
CUBLAS_GEMM_ALGO15 = 15,
CUBLAS_GEMM_ALGO16 = 16,
CUBLAS_GEMM_ALGO17 = 17,
CUBLAS_GEMM_ALGO18 = 18, //sliced 32x32
CUBLAS_GEMM_ALGO19 = 19, //sliced 64x32
CUBLAS_GEMM_ALGO20 = 20, //sliced 128x32
CUBLAS_GEMM_ALGO21 = 21, //sliced 32x32 -splitK
CUBLAS_GEMM_ALGO22 = 22, //sliced 64x32 -splitK
CUBLAS_GEMM_ALGO23 = 23, //sliced 128x32 -splitK
CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99,
CUBLAS_GEMM_DFALT_TENSOR_OP = 99,
CUBLAS_GEMM_ALGO0_TENSOR_OP = 100,
CUBLAS_GEMM_ALGO1_TENSOR_OP = 101,
CUBLAS_GEMM_ALGO2_TENSOR_OP = 102,
CUBLAS_GEMM_ALGO3_TENSOR_OP = 103,
CUBLAS_GEMM_ALGO4_TENSOR_OP = 104,
CUBLAS_GEMM_ALGO5_TENSOR_OP = 105,
CUBLAS_GEMM_ALGO6_TENSOR_OP = 106,
CUBLAS_GEMM_ALGO7_TENSOR_OP = 107,
CUBLAS_GEMM_ALGO8_TENSOR_OP = 108,
CUBLAS_GEMM_ALGO9_TENSOR_OP = 109,
CUBLAS_GEMM_ALGO10_TENSOR_OP = 110,
CUBLAS_GEMM_ALGO11_TENSOR_OP = 111,
CUBLAS_GEMM_ALGO12_TENSOR_OP = 112,
CUBLAS_GEMM_ALGO13_TENSOR_OP = 113,
CUBLAS_GEMM_ALGO14_TENSOR_OP = 114,
CUBLAS_GEMM_ALGO15_TENSOR_OP = 115
} cublasGemmAlgo_t;
typedef enum cudaDataType_t
{
CUDA_R_16F= 2, /* real as a half */
CUDA_C_16F= 6, /* complex as a pair of half numbers */
CUDA_R_32F= 0, /* real as a float */
CUDA_C_32F= 4, /* complex as a pair of float numbers */
CUDA_R_64F= 1, /* real as a double */
CUDA_C_64F= 5, /* complex as a pair of double numbers */
CUDA_R_8I = 3, /* real as a signed char */
CUDA_C_8I = 7, /* complex as a pair of signed char numbers */
CUDA_R_8U = 8, /* real as a unsigned char */
CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */
CUDA_R_32I= 10, /* real as a signed int */
CUDA_C_32I= 11, /* complex as a pair of signed int numbers */
CUDA_R_32U= 12, /* real as a unsigned int */
CUDA_C_32U= 13 /* complex as a pair of unsigned int numbers */
} cudaDataType;
typedef cudaDataType cublasDataType_t;
typedef enum {
CUBLAS_OP_N=0,
CUBLAS_OP_T=1,
CUBLAS_OP_C=2,
CUBLAS_OP_HERMITAN=2, /* synonym of CUBLAS_OP_C */
CUBLAS_OP_CONJG=3 /* conjugate */
} cublasOperation_t;
/*Enum for default math mode/tensor operation*/
typedef enum {
CUBLAS_DEFAULT_MATH = 0,
CUBLAS_TENSOR_OP_MATH = 1
} cublasMath_t;
#endif

View File

@@ -1,215 +0,0 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include <tuple>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/dot.h"
#include "cuda/cublas.h"
#include "util.h"
template<class T, bool AT, bool BT>
static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
float acc = 0;
for(size_t k = 0; k < K; k++)
acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]);
c[m*N + n] = static_cast<T>(acc);
}
}
template<class T>
void cc_dot(bool AT_, bool BT_, size_t M, size_t N, size_t K,
std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
cc_dot<T, true, true>(c, a, b, M, N, K);
else if(AT_ && !BT_)
cc_dot<T, true, false>(c, a, b, M, N, K);
else if(!AT_ && BT_)
cc_dot<T, false, true>(c, a, b, M, N, K);
else
cc_dot<T, false, false>(c, a, b, M, N, K);
}
enum run_mode_t {
BENCH,
TEST
};
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<class T>
void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
int32_t TM, int32_t TN, int32_t TK, int32_t nwarp,
const std::vector<int>& a_order, const std::vector<int>& b_order,
run_mode_t mode, std::vector<double>& bench, bool &test){
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::device* device = context->device();
int32_t lda = (AT ^ a_order[0]==1) ? K : M;
int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
int32_t ldc = N;
std::vector<std::string> sa = { "1", "lda" };
std::vector<std::string> sb = { "1", "ldb" };
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
// ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
// macros
rt::options_space_t opts;
// A access patterns
opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
// B access patterns
opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
// data-type
opts.defines.push_back({"TYPE", {ty}});
// tile sizes
if(mode == TEST) {
opts.defines.push_back({"TM", {std::to_string(TM)}});
opts.defines.push_back({"TN", {std::to_string(TN)}});
opts.defines.push_back({"TK", {std::to_string(TK)}});
opts.defines.push_back({"TZ", {"1"}});
opts.num_warps = {nwarp};
}
if(mode == BENCH) {
opts.defines.push_back({"TM", {"128"}});
opts.defines.push_back({"TN", {"128"}});
opts.defines.push_back({"TK", {"32"}});
opts.defines.push_back({"TZ", {"1"}});
opts.num_warps = {4};
}
// arguments
std::stringstream oss;
rt::add_arg(oss, *da->cu());
rt::add_arg(oss, *db->cu());
rt::add_arg(oss, *dc->cu());
rt::add_arg(oss, (float)1);
rt::add_arg(oss, M);
rt::add_arg(oss, N);
rt::add_arg(oss, K);
rt::add_arg(oss, lda);
rt::add_arg(oss, ldb);
rt::add_arg(oss, ldc);
rt::add_arg(oss, *dlocks->cu());
// kernel
rt::function function(src::dot, opts);
// grid
auto grid = [M, N](const rt::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM"))*
ceil(N, x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
// metrics
if(mode == BENCH){
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
bench.push_back(tflops(triton_ns));
// cublas
if(cublas::cublasinit()){
T alpha(static_cast<double>(1));
T beta(static_cast<double>(0));
cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, !AT, !BT, M, N, K,
&alpha, &*da, lda, &*db, ldb, &beta, &*dc,
ldc); }, stream);
bench.push_back(tflops(cublas_ms));
}
}
// rt::options_t opt;
// for(auto &x: opts.defines)
// opt.defines[x.first] = x.second[0];
// opt.num_warps = 1;
// std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;
// test triton
if(mode == TEST){
srand(0);
// initialize buffers
std::vector<T> hc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
// copy buffer
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
// run kernel
function((void**)oss.str().data(), oss.str().size(), grid, stream, device);
// write back
stream->synchronize();
// compare with CPU
stream->read(&*dc, true, 0, hc);
std::vector<T> rc(hc.size());
cc_dot(AT, BT, M, N, K, rc, ha, hb);
test = testing::diff(hc, rc);
}
}
std::vector<double> bench_dot(drv::context* context, drv::stream* stream,
dtype_t dtype, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
const std::vector<int>& a_order, const std::vector<int>& b_order) {
std::vector<double> bench;
bool test;
switch(dtype){
case HALF: triton_dot<half_float::half>(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
case FLOAT: triton_dot<float>(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
case DOUBLE: triton_dot<double>(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
default: break;
}
return bench;
}
bool test_dot(drv::context* context, drv::stream* stream,
dtype_t dtype, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
const std::vector<int>& a_order, const std::vector<int>& b_order,
int32_t TM, int32_t TN, int32_t TK, size_t nwarp) {
std::vector<double> bench;
bool test = false;
switch(dtype){
case HALF: triton_dot<half_float::half>(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
case FLOAT: triton_dot<float>(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
case DOUBLE: triton_dot<double>(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
default: break;
}
return test;
}

View File

@@ -1,173 +0,0 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/reduce.h"
#include "util.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
template<class T>
void cc_reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
assert(axis <= shapes.size() - 1);
// remove shape at index axis to get outer dimensions
std::vector<int> outer = shapes;
outer.erase(outer.begin() + axis);
if(outer.empty())
outer.push_back(1);
// retrieve shape at index axis to get inner dimension
int inner = shapes[axis];
// accumulation function
auto acc = get_accumulator<T>(op);
// iterate over outer dimensions
_loop_nest(outer, [&](const std::vector<int>& y_idx) {
T ret = 0;
auto x_idx = y_idx;
x_idx.insert(x_idx.begin() + axis, 0);
// accumulate over inner dimensions
for(int z = 0; z < inner; z++){
x_idx[axis] = z;
ret = acc(ret, x[offset(x_idx, shapes)]);
}
y[offset(y_idx, outer)] = ret;
});
}
enum run_mode_t {
BENCH,
TEST
};
void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vector<int32_t>& shape_x,
int axis, reduce_op_t op,
const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order,
std::vector<std::vector<std::string>> TS,
run_mode_t mode, std::vector<double>& bench, bool &test) {
typedef float NumericT;
std::string ty = "float";
size_t dtsize = sizeof(NumericT);
drv::device* device = context->device();
// shape
std::vector<int> shape_y = shape_x;
shape_y.erase(shape_y.begin() + axis);
// rank
int rank_x = shape_x.size();
int rank_y = shape_y.size();
// size
size_t size_x = 1;
for(int32_t d: shape_x)
size_x *= d;
size_t size_y = 1;
for(int32_t d: shape_y)
size_y *= d;
// strides for x
std::vector<std::string> x_shapename = {"S0", "S1", "S2"};
std::vector<std::string> x_strides = {"1"};
for(int d = 0; d < rank_x - 1; d++)
x_strides.push_back(x_strides[d] + " * " + x_shapename[x_order[d]]);
// strides for y
std::vector<std::string> y_shapename = x_shapename;
y_shapename.erase(y_shapename.begin() + axis);
std::vector<std::string> y_strides = {"1"};
for(int d = 0; d < rank_y - 1; d++)
y_strides.push_back(y_strides[d] + " * " + y_shapename[y_order[d]]);
// options
rt::options_space_t opts;
opts.defines.push_back({"TYPE", {ty}});
for(int d = 0; d < rank_x; d++)
opts.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
for(int d = 0; d < rank_y; d++)
opts.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
if(TS.empty())
TS = tile_nd(rank_x);
for(int d = 0; d < rank_x; d++)
opts.defines.push_back({"TS" + std::to_string(d), TS[d]});
std::vector<size_t> axy;
for(int d = 0; d < rank_x; d++)
if(d != axis)
axy.push_back(d);
for(int d = 0; d < rank_y; d++)
opts.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
for(int d = 0; d < rank_y; d++)
opts.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});
std::string RED = "";
for(int n = 0; n < rank_x; n++){
if(n > 0)
RED += ", ";
RED += (n==axis) ? to_str(op) : ":";
}
opts.defines.push_back({"RED", {RED}});
opts.num_warps = {2};
// kernel
rt::function function(src::reduce_nd[rank_x - 1], opts);
// input buffers
auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_x*dtsize));
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));
// grid
std::stringstream oss;
rt::add_arg(oss, *dx->cu());
rt::add_arg(oss, *dy->cu());
rt::add_arg(oss, (uint32_t)shape_x[0]);
if(shape_x.size() > 1) rt::add_arg(oss, (uint32_t)shape_x[1]);
if(shape_x.size() > 2) rt::add_arg(oss, (uint32_t)shape_x[2]);
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
auto grid = grid_nd(shape_x, ts);
// metrics
if(mode == BENCH){
auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
bench.push_back(gbps(triton_ns));
}
// rt::options_t opt;
// for(auto &x: opts.defines)
// opt.defines[x.first] = x.second[0];
// opt.num_warps = 1;
// std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;
// test triton
if(mode == TEST){
std::vector<NumericT> hy(size_y);
std::vector<NumericT> ry(size_y);
std::vector<NumericT> hx(size_x);
init_zeros(hy);
init_rand(hx);
stream->write(&*dx, true, 0, hx);
function((void**)oss.str().data(), oss.str().size(), grid, stream, device);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
cc_reduce_nd(ry, hx, op, axis, shape_x);
test = testing::diff(hy, ry);
}
}
bool do_test(drv::context* context, drv::stream* stream, std::vector<int> shape, int axis, reduce_op_t op, int nwarp){
std::vector<double> bench;
bool test;
std::vector<std::vector<std::string>> TSS;
for(int32_t d: shape)
TSS.push_back({std::to_string(d)});
triton_reduce_nd(context, stream, shape, axis, op, {0, 1, 2}, {0, 1, 2}, TSS, TEST, bench, test);
return test;
}

View File

@@ -1,134 +0,0 @@
namespace src {
const char *conv =
R"(
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M __retune,
int N __retune,
int K __retune,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs [TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr [TK] = rcir % RR;
int rci [TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz [:, newaxis] * lda_z +
rci[newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE* pa[TM, TK] = A + offa;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr [:, newaxis] * ldb_r +
rs [:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE* pb[TK, TN] = B + offb;
// pointers to delta
int* padelta[TK] = ADELTA + rk;
int adelta[TK] = *padelta;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkbn[TK, TN] = rn[newaxis, :] < N;
bool checkbk[TK, TN] = rk[:, newaxis] < K;
bool checkb[TK, TN] = checkbn && checkbk;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
// increment A
pa += adelta[newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkbk = k > TK;
bool checkb[TK, TN] = checkbn && checkbk;
a = checka ? *pa : 0;
b = *?(checkb)pb;
padelta += TK;
adelta = *padelta;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co +
rp [:, newaxis] * ldc_p +
rm [:, newaxis] * 1;
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";
}
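
A hedged host-side sketch (an assumption, not code from this commit) of how the ADELTA pointer-increment table consumed by the kernel above could be built: each entry holds the change in A's flat offset when the reduction index jumps from k to k + TK, with k unpacked into (ci, r, s) exactly as the kernel does.

#include <vector>
// strides follow the kernel's offa computation: ci * lda_ci + r * lda_h + s * 1
std::vector<int> build_adelta(int TK, int CI, int R, int S, int lda_ci, int lda_h) {
  auto off = [&](int k) {
    int s  = k % S;
    int r  = (k / S) % R;
    int ci = k / (S * R);
    return ci * lda_ci + r * lda_h + s;
  };
  std::vector<int> delta(CI * R * S);
  for (int k = 0; k < CI * R * S; k++)
    delta[k] = off(k + TK) - off(k);
  return delta;
}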

View File

@@ -1,67 +0,0 @@
#ifndef _TRITON_TEST_SRC_COPY_H_
#define _TRITON_TEST_SRC_COPY_H_
namespace src {
const char *copy1d =
R"(
void copy1d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __readonly __aligned(16),
int S0 __retune) {
int pid0 = get_program_id(0);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
TYPE* px[TS0] = X + rs0;
TYPE* py[TS0] = Y + rs0;
*py = *px;
}
)";
const char *copy2d =
R"(
void copy2d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0 __multipleof(8) __retune,
int S1 __multipleof(8) __retune) {
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
bool in_bounds[TS0, TS1] = rs0[:, newaxis] < S0 && rs1[newaxis, :] < S1;
TYPE* px[TS0, TS1] = X + rs0[:, newaxis] * STRIDE_XS0 + rs1[newaxis, :] * STRIDE_XS1;
TYPE* py[TS0, TS1] = Y + rs0[:, newaxis] * STRIDE_YS0 + rs1[newaxis, :] * STRIDE_YS1;
*?(in_bounds)py = *?(in_bounds)px;
}
)";
const char *copy3d =
R"(
void copy3d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0 __multipleof(8) __retune,
int S1 __multipleof(8) __retune,
int S2 __multipleof(8) __retune) {
// program id
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pid2 = get_program_id(2);
// ranges
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
int rs2[TS2] = pid2 * TS2 + 0 ... TS2;
// X pointers
TYPE* px[TS0, TS1, TS2] = X + rs0[:, newaxis, newaxis] * STRIDE_XS0
+ rs1[newaxis, :, newaxis] * STRIDE_XS1
+ rs2[newaxis, newaxis, :] * STRIDE_XS2;
// Y pointers
TYPE* py[TS0, TS1, TS2] = Y + rs0[:, newaxis, newaxis] * STRIDE_YS0
+ rs1[newaxis, :, newaxis] * STRIDE_YS1
+ rs2[newaxis, newaxis, :] * STRIDE_YS2;
*py = *px;
}
)";
const char* copy_nd[] = {copy1d, copy2d, copy3d};
}
#endif
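
A hedged sketch (an assumption, mirroring the define-building convention of the reduce test above) of how copy2d might be instantiated, with dimension 0 of both tensors contiguous:

namespace rt = triton::runtime;
rt::options_space_t copy2d_opts() {
  rt::options_space_t opts;
  opts.defines.push_back({"TYPE",       {"float"}});
  opts.defines.push_back({"STRIDE_XS0", {"1"}});   // dim 0 of X is contiguous
  opts.defines.push_back({"STRIDE_XS1", {"S0"}});  // dim 1 of X strides over S0 elements
  opts.defines.push_back({"STRIDE_YS0", {"1"}});
  opts.defines.push_back({"STRIDE_YS1", {"S0"}});
  opts.defines.push_back({"TS0", {"16", "32", "64"}});  // candidate tile sizes
  opts.defines.push_back({"TS1", {"16", "32", "64"}});
  opts.num_warps = {4};
  return opts;
}
// rt::function copy2d_fn(src::copy2d, copy2d_opts());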

View File

@@ -1,58 +0,0 @@
namespace src {
const char *reduce1d =
R"(
void reduce1d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __readonly __aligned(16),
int S0) {
int pid0 = get_program_id(0);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
TYPE* px[TS0] = X + rs0;
*Y = (*px)[RED];
}
)";
const char *reduce2d =
R"(
void reduce2d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0, int S1) {
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
TYPE* px[TS0, TS1] = X + rs0[:, newaxis] * STRIDE_XS0
+ rs1[newaxis, :] * STRIDE_XS1;
TYPE* py[TY0] = Y + RY0 * STRIDE_YS0;
*py = (*px)[RED];
}
)";
const char *reduce3d =
R"(
void reduce3d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0, int S1, int S2) {
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pid2 = get_program_id(2);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
int rs2[TS2] = pid2 * TS2 + 0 ... TS2;
// input pointers
TYPE* px[TS0, TS1, TS2] = X + rs0[:, newaxis, newaxis] * STRIDE_XS0
+ rs1[newaxis, :, newaxis] * STRIDE_XS1
+ rs2[newaxis, newaxis, :] * STRIDE_XS2;
// output pointers
TYPE* py[TY0, TY1] = Y + RY0[:, newaxis] * STRIDE_YS0
+ RY1[newaxis, :] * STRIDE_YS1;
// write-back
*py = (*px)[RED];
}
)";
const char* reduce_nd[] = {reduce1d, reduce2d, reduce3d};
}
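
To make the specialization concrete (this follows the define-building logic of the reduce test above, so it is an illustration rather than part of the diff): summing an S0 x S1 float tensor over axis 1 sets TYPE to float, STRIDE_XS0 to 1, STRIDE_XS1 to S0, STRIDE_YS0 to 1, TY0 to the value of S0, RY0 to rs0 and RED to ":, +", so the write-back in reduce2d compiles to *py = (*px)[:, +], i.e. one sum per row.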

View File

@@ -1,225 +0,0 @@
#pragma once
#ifndef _TRITON_TESTS_UTIL_H
#define _TRITON_TESTS_UTIL_H
#include <iomanip>
#include <cmath>
#include "triton/runtime/function.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
/* ------------------------
* Launch Grid
* ------------------------ */
inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
}
inline rt::function::grid_fn_ty grid1d(size_t N) {
return [N](const rt::options_t& x) {
return rt::grid_t{ceil(N, x.D<int>("TN"))};
};
}
inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
return [M, N](const rt::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))};
};
}
inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape,
const std::vector<std::string>& ts) {
return [&shape, &ts](const rt::options_t& x) {
rt::grid_t ret;
for(size_t d = 0; d < shape.size(); d++)
ret.push_back(ceil(shape[d], x.D<int>(ts[d])));
return ret;
};
}
inline std::vector<std::vector<std::string>> tile_nd(size_t rank) {
assert(rank <= 3);
if(rank == 1)
return {{"128", "256", "512", "1024"}};
if(rank == 2)
return {{"16", "32", "64"},
{"16", "32", "64"}};
if(rank == 3)
return {{"4", "16", "32"},
{"4", "16", "32"},
{"4", "16", "32"}};
return {};
}
/* ------------------------
* Tensor Initialization
* ------------------------ */
template<class T>
void init_rand(std::vector<T>& x) {
for(size_t i = 0; i < x.size(); i++)
x[i] = i;
}
template<class T>
void init_zeros(std::vector<T>& x) {
for(size_t i = 0; i < x.size(); i++)
x[i] = 0;
}
template<class T>
void init_ones(std::vector<T>& x) {
for(size_t i = 0; i < x.size(); i++)
x[i] = 1;
}
/* ------------------------
* Loop Nests
* ------------------------ */
void _loop_nest(std::vector<int> const & ranges,
std::function<void(std::vector<int> const &)> const & f){
int D = ranges.size();
std::vector<int> values(D, 0);
// Start with innermost loop
int i = D - 1;
while(true){
// Execute function
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}
/* -----------------------
* TENSOR INDEXING
* ----------------------- */
enum order_t {
ROWMAJOR,
COLMAJOR
};
int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
int result = idx[0];
int ld = 1;
for(int i = 1; i < idx.size(); i++){
ld *= shapes[i - 1];
result += idx[i]*ld;
}
return result;
}
/* -----------------------
* REDUCTION HELPERS
* ----------------------- */
enum reduce_op_t {
ADD,
MAX,
MIN
};
std::string to_str(reduce_op_t op) {
switch (op) {
case ADD: return "+";
case MAX: return "max";
case MIN: return "min";
default: break;
}
assert(false);
return "";
}
template<class T>
std::function<T(T,T)> get_accumulator(reduce_op_t op) {
switch (op) {
case ADD: return [](T x, T y) { return x + y; };
case MAX: return [](T x, T y) { return std::max(x, y); };
case MIN: return [](T x, T y) { return std::min(x, y); };
default: break;
}
assert(false);
return std::function<T(T,T)>();
}
/* -----------------------
* TENSOR COMPARISON
* ----------------------- */
namespace testing {
template<class T>
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
return false;
}
return true;
}
}
/* -----------------------
* PRETTY PRINTING
* ----------------------- */
namespace aux{
template<std::size_t...> struct seq{};
template<std::size_t N, std::size_t... Is>
struct gen_seq : gen_seq<N-1, N-1, Is...>{};
template<std::size_t... Is>
struct gen_seq<0, Is...> : seq<Is...>{};
template<class Ch, class Tr, class Tuple, std::size_t... Is>
void print_tuple(std::basic_ostream<Ch,Tr>& os, Tuple const& t, seq<Is...>){
using swallow = int[];
(void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get<Is>(t)), 0)...};
}
} // aux::
template<class Ch, class Tr, class... Args>
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
-> std::basic_ostream<Ch, Tr>&
{
aux::print_tuple(os, t, aux::gen_seq<sizeof...(Args)>());
return os;
}
template<class Ch, class Tr, class T>
std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, const std::vector<T>& vec) {
os << "{";
for(size_t i = 0; i < vec.size(); i++){
if(i > 0)
os << ", ";
os << vec[i];
}
os << "}";
return os;
}
template<class Ch, class Tr>
std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, reduce_op_t op) {
return os << to_str(op);
}
#endif
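
A small usage sketch (not part of the header above) showing how _loop_nest and offset compose: it walks every index of a 2 x 3 tensor and prints its flat offset, where offset() treats idx[0] as the contiguous dimension.

#include <iostream>
#include <vector>
void example_walk() {
  std::vector<int> shape = {2, 3};
  _loop_nest(shape, [&](const std::vector<int>& idx) {
    std::cout << idx << " -> offset " << offset(idx, shape) << std::endl;
  });
}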

View File

@@ -1,6 +0,0 @@
foreach(PROG dot copy reduce)
set(TARGET unit_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
target_link_libraries(${TARGET} triton dl)
endforeach(PROG)

View File

@@ -1,60 +0,0 @@
#include <iostream>
#include <tuple>
#include "copy.h"
#include "triton/driver/backend.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
std::vector<config_t> configs;
// 1D
// configs.push_back({{65536}, {32}, {0}, {0}});
configs.push_back({{65536}, {128}, {0}, {0}});
configs.push_back({{65536}, {512}, {0}, {0}});
configs.push_back({{65536}, {1024}, {0}, {0}});
// 2D
configs.push_back({{256, 256}, {16, 16}, {0, 1}, {0, 1}});
configs.push_back({{256, 256}, {16, 64}, {0, 1}, {0, 1}});
configs.push_back({{256, 256}, {64, 16}, {0, 1}, {0, 1}});
configs.push_back({{256, 256}, {64, 64}, {0, 1}, {0, 1}});
configs.push_back({{256, 256}, {16, 16}, {0, 1}, {1, 0}});
configs.push_back({{256, 256}, {16, 64}, {0, 1}, {1, 0}});
configs.push_back({{256, 256}, {64, 16}, {0, 1}, {1, 0}});
configs.push_back({{256, 256}, {64, 64}, {0, 1}, {1, 0}});
configs.push_back({{256, 256}, {16, 16}, {1, 0}, {0, 1}});
configs.push_back({{256, 256}, {16, 64}, {1, 0}, {0, 1}});
configs.push_back({{256, 256}, {64, 16}, {1, 0}, {0, 1}});
configs.push_back({{256, 256}, {64, 64}, {1, 0}, {0, 1}});
configs.push_back({{256, 256}, {64, 64}, {1, 0}, {1, 0}});
configs.push_back({{256, 256}, {16, 64}, {1, 0}, {1, 0}});
configs.push_back({{256, 256}, {64, 16}, {1, 0}, {1, 0}});
configs.push_back({{256, 256}, {64, 64}, {1, 0}, {1, 0}});
// 3D
std::vector<std::vector<int>> xx_idx = {{0, 1, 2}, {2, 1, 0}, {1, 0, 2}};
std::vector<std::vector<int>> yy_idx = {{0, 1, 2}, {2, 1, 0}, {1, 0, 2}};
for(const auto& x_idx: xx_idx)
for(const auto& y_idx: yy_idx){
configs.push_back({{64, 64, 32}, {16, 4, 8}, x_idx, y_idx});
configs.push_back({{64, 64, 32}, {8, 16, 2}, x_idx, y_idx});
configs.push_back({{64, 64, 32}, {32, 2, 2}, x_idx, y_idx});
configs.push_back({{64, 64, 32}, {16, 64, 4}, x_idx, y_idx});
}
// testing
std::vector<int32_t> shape, tile;
std::vector<int32_t> ord_x, ord_y;
bool result = true;
for(const auto& c: configs){
std::tie(shape, tile, ord_x, ord_y) = c;
bool pass = test_copy_nd(context, stream, FLOAT, shape, tile, ord_x, ord_y);
result = result && pass;
std::cout << "// " << c << ", " << pass << std::endl;
}
return result;
}

View File

@@ -1,70 +0,0 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "dot.h"
#include "util.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to test
typedef std::tuple<dtype_t, bool, bool, int, int, int, int> config_t;
std::vector<config_t> configs;
for(dtype_t dtype: std::vector<dtype_t>{FLOAT, HALF})
for(bool AT: std::vector<bool>{false, true})
for(bool BT: std::vector<bool>{false, true}){
// 1 warp
configs.push_back({dtype, AT, BT, 16, 16, 16, 1});
configs.push_back({dtype, AT, BT, 32, 16, 16, 1});
configs.push_back({dtype, AT, BT, 16, 32, 16, 1});
configs.push_back({dtype, AT, BT, 16, 16, 32, 1});
configs.push_back({dtype, AT, BT, 32, 16, 32, 1});
configs.push_back({dtype, AT, BT, 16, 32, 32, 1});
if(dtype == HALF){
configs.push_back({dtype, AT, BT, 16, 64, 16, 1});
configs.push_back({dtype, AT, BT, 16, 16, 64, 1});
configs.push_back({dtype, AT, BT, 64, 16, 64, 1});
configs.push_back({dtype, AT, BT, 16, 64, 64, 1});
}
// 2 warps
configs.push_back({dtype, AT, BT, 64, 32, 64, 2});
configs.push_back({dtype, AT, BT, 32, 64, 64, 2});
configs.push_back({dtype, AT, BT, 64, 32, 16, 2});
configs.push_back({dtype, AT, BT, 32, 64, 16, 2});
configs.push_back({dtype, AT, BT, 128, 32, 32, 2});
configs.push_back({dtype, AT, BT, 32, 128, 32, 2});
// 4 warps
configs.push_back({dtype, AT, BT, 128, 64, 16, 4});
configs.push_back({dtype, AT, BT, 64, 128, 16, 4});
configs.push_back({dtype, AT, BT, 128, 32, 32, 4});
configs.push_back({dtype, AT, BT, 32, 128, 32, 4});
if(dtype == HALF){
configs.push_back({dtype, AT, BT, 128, 32, 64, 4});
configs.push_back({dtype, AT, BT, 32, 128, 64, 4});
}
// 8 warps
configs.push_back({dtype, AT, BT, 128, 256, 16, 8});
configs.push_back({dtype, AT, BT, 256, 128, 16, 8});
if(dtype == HALF){
configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
}
};
// test
dtype_t dtype;
bool AT, BT;
int M, N, K, TM, TN, TK, nwarp;
for(const auto& c: configs){
std::tie(dtype, AT, BT, TM, TN, TK, nwarp) = c;
M = TM;
N = TN;
K = TK;
std::cout << "Testing " << c << " ... " << std::flush;
if(test_dot(context, stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
std::cout << " Pass! " << std::endl;
else{
std::cout << " Fail! " << std::endl;
}
}
}

View File

@@ -1,45 +0,0 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include <functional>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "cuda/cublas.h"
#include "reduce.h"
#include "util.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
std::vector<config_t> configs = {
config_t{{64}, 0, ADD},
config_t{{128}, 0, MIN},
config_t{{32, 32}, 0, MAX},
config_t{{32, 32}, 1, ADD},
config_t{{32, 64}, 0, ADD},
config_t{{64, 32}, 1, ADD},
config_t{{8, 8, 4}, 2, ADD},
config_t{{8, 8, 4}, 0, ADD},
config_t{{8, 8, 4}, 1, ADD}
};
// does the work
int axis;
std::vector<int> shape;
reduce_op_t op;
for(const auto& c: configs){
std::tie(shape, axis, op) = c;
std::cout << "Testing " << c << " ... " << std::flush;
if(do_test(context, stream, shape, axis, op, 1))
std::cout << " Pass! " << std::endl;
else
std::cout << " Fail! " << std::endl;
}
}

tutorials/01-matmul.cc (new file, 245 lines)
View File

@@ -0,0 +1,245 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include <tuple>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include <iomanip>
#include <cmath>
#include "triton/runtime/function.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
namespace src {
const char *dot =
R"(
#define STM 8
#define STN 8
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
int* locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
int width = STM*gridn;
int stm = pid / width;
int RSTM = min(gridm - stm*STM, STM);
int stn = (pid % width) / (RSTM*STN);
int RSTN = min(gridn - stn*STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M &&
rcn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";
}
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<class T>
void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
const std::vector<int>& a_order, const std::vector<int>& b_order,
std::vector<double>& bench, bool &test){
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::device* device = context->device();
int32_t lda = (AT ^ a_order[0]==1) ? K : M;
int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
int32_t ldc = N;
std::vector<std::string> sa = { "1", "lda" };
std::vector<std::string> sb = { "1", "ldb" };
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
// initialize buffers
std::vector<T> hc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
// copy buffer
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
// macros
rt::options_space_t opts;
// A access patterns
opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
// B access patterns
opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
// data-type
opts.defines.push_back({"TYPE", {ty}});
// tile sizes
opts.defines.push_back({"TM", {"128"}});
opts.defines.push_back({"TN", {"128"}});
opts.defines.push_back({"TK", {"32"}});
opts.defines.push_back({"TZ", {"1"}});
opts.num_warps = {4};
// arguments
std::stringstream oss;
rt::add_arg(oss, *da->cu());
rt::add_arg(oss, *db->cu());
rt::add_arg(oss, *dc->cu());
rt::add_arg(oss, (float)1);
rt::add_arg(oss, M);
rt::add_arg(oss, N);
rt::add_arg(oss, K);
rt::add_arg(oss, lda);
rt::add_arg(oss, ldb);
rt::add_arg(oss, ldc);
rt::add_arg(oss, *dlocks->cu());
// kernel
rt::function function(src::dot, opts, device);
// grid
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [ceil, M, N](const rt::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM"))*
ceil(N, x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
// metrics
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
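// For example, M = N = K = 8192 gives 2*8192^3 ~ 1.1e12 FLOP, so a 10 ms
// (1e7 ns) run reports roughly 1.1e12 / 1e7 * 1e-3 ~ 110 TFLOPS.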
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream);
bench.push_back(tflops(triton_ns));
}
std::vector<double> bench_dot(drv::context* context, drv::stream* stream,
dtype_t dtype, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
const std::vector<int>& a_order, const std::vector<int>& b_order) {
std::vector<double> bench;
bool test;
switch(dtype){
case HALF: triton_dot<half_float::half>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
case FLOAT: triton_dot<float>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
case DOUBLE: triton_dot<double>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
default: break;
}
return bench;
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
// shapes to benchmark
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs = {
{{1, 0}, false, false, 8192, 8192, 8192}
};
// does the work
std::vector<int> ord;
bool AT, BT;
int32_t M, N, K;
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K ;
for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

View File

@@ -1,5 +1,5 @@
-foreach(PROG dot copy conv)
-  set(TARGET bench_${PROG})
+foreach(PROG 01-matmul)
+  set(TARGET ${PROG})
   add_executable(${TARGET} ${PROG}.cc)
   set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
   target_link_libraries(${TARGET} triton dl)