44 Commits
v0.1 ... v0.2.3

Author SHA1 Message Date
Philippe Tillet
526a3468e0 [Version] Now version 0.2.3 2020-09-03 14:36:30 -04:00
Philippe Tillet
5700260e5e Merge pull request #41 from jeffra/fix-conda-build
fix llvm build inside conda environment
2020-09-03 14:28:33 -04:00
Jeff Rasley
91d21134a7 fix llvm build inside conda environment (see link for similar issue)
https://github.com/tensorflow/tensorflow/issues/12998
2020-09-03 11:26:51 -07:00
Philippe Tillet
010f85e91d [VERSION] Now version 0.2.2 2020-08-31 15:16:09 -04:00
Philippe Tillet
3195aca452 [PYTHON] Added credits to Scott Gray for the idea used in launch.cc 2020-08-12 19:39:57 -04:00
Philippe Tillet
d5610b685b [TESTS] Updated the test to be compatible with the new runtime API 2020-08-12 19:37:08 -04:00
Philippe Tillet
30ac1359b9 [RUNTIME] Lower-level interface for executing functions 2020-08-12 18:33:35 -04:00
Philippe Tillet
8e5744f171 [EXAMPLES] Added C++ example for Conv2d 2020-08-10 11:53:34 -04:00
Philippe Tillet
2d6484482f [CODEGEN][ANALYSIS] Fixed issue in layout inference 2020-08-10 11:53:11 -04:00
Philippe Tillet
bd2067606c [EXAMPLES] Improved mat_mul example 2020-08-06 17:29:52 -04:00
Philippe Tillet
f01bdd6207 [EXAMPLES] Added conv2d example 2020-08-06 17:29:52 -04:00
Philippe Tillet
ec923cfc5e [PACKAGING] Now version 0.2.1 2020-07-29 14:07:28 -04:00
Philippe Tillet
cfcda09679 [DRIVER] Now always using PTXv6.4 2020-07-29 13:50:01 -04:00
Philippe Tillet
3683b38963 [PACKAGING] Now version 0.2.0 2020-07-14 19:50:39 -04:00
Philippe Tillet
cc7c77246b [EXAMPLES][TUTORIAL] Changed to new triton.kernel API 2020-07-08 13:39:19 -04:00
Philippe Tillet
baabaf0d58 [TRITON][PYTHON] Reverted back to distutils 2020-06-24 16:39:49 -04:00
Philippe Tillet
8e3ec0556f [TRITON][KERNEL] Fixed issue for concurrent compilation of torch
extensions
2020-06-24 15:47:43 -04:00
Philippe Tillet
7c519e09c5 [ANALYSIS] Replaced min by gcd in layout inference 2020-06-06 17:42:36 -04:00
Philippe Tillet
8e9d793d11 [CODEGEN] Fixed various issues in alignment inference pass 2020-06-06 11:28:43 -04:00
Philippe Tillet
547434d7f0 [CODEGEN] Fixed bug in alignment inference that prevented vectorization
in some cases
2020-06-06 01:13:38 -04:00
Philippe Tillet
c1f6ce68f9 [README] Changed requirement to LLVM-9 2020-06-05 12:05:21 -04:00
Philippe Tillet
b48460f51c [PACKAGING] Now version 0.1.3 2020-06-05 12:01:09 -04:00
Philippe Tillet
52eca7676b [TRITON] Fixed misaligned address issue 2020-06-05 12:00:42 -04:00
Philippe Tillet
a92031b214 [CODEGEN] Removed unnecessary coalescing rematerialization 2020-05-30 15:03:48 -04:00
Philippe Tillet
18a4cefec7 [CORE] Auto-tuning now copies scalar buffers. Still needs to copy all buffers that are both read from and written to. 2020-05-15 23:21:42 -04:00
Philippe Tillet
735215d9c4 [PYTHON] Added support for FP16 scalar kernel arguments 2020-05-15 16:37:25 -04:00
Philippe Tillet
05214d22e3 [CODEGEN] Bugfix in Disassociate pass; Added fp32 atomic_add support 2020-05-13 23:21:21 -04:00
Philippe Tillet
d390d941d8 [LANG] Added support for flattening 2020-05-13 20:44:09 -04:00
Philippe Tillet
f8614b2490 [PACKAGING] Now version 0.1.2 2020-05-12 20:50:52 -04:00
Philippe Tillet
8f9233e546 [LANG] Fixed undefined behavior in replace_all_uses_with() 2020-05-12 20:31:10 -04:00
Philippe Tillet
cd21151b98 [GENERAL] Fixed some undefined behavior with GCC-9 2020-05-11 11:07:21 -04:00
Philippe Tillet
9da8fe11ed [CODEGEN] Fixed bug that caused missing recoalescing for some transpose
operations
2020-05-11 00:26:03 -04:00
Philippe Tillet
65a41af8dc [README] Improved wording 2020-05-10 00:35:41 -04:00
Philippe Tillet
142a1dd42b [PYTHON] Restored compatibility with powerpc 2020-05-09 23:46:25 -04:00
Philippe Tillet
f88a7a0027 [PYTHON] Removed -std=gnu++11 in extra_cflags 2020-05-09 22:57:40 -04:00
Philippe Tillet
bfa85ef030 [GENERAL] Added compatibility with pytorch 1.2.0 and powerpc 2020-05-09 21:58:57 -04:00
Philippe Tillet
3848d25450 [DOCS] Added pip command in README.md 2020-05-07 10:48:41 -04:00
Philippe Tillet
1a20099582 [DOCS] Now specifying pip command in installation.rst 2020-05-07 10:45:58 -04:00
Phillippe Tillet
fc3747d59c Merge pull request #38 from jack-willturner/master
Add working examples to tutorials and python examples folder
2020-05-07 10:24:51 -04:00
jack-willturner
5fddc2062e [DOCS] Transposition fix 2020-05-07 14:02:18 +01:00
jack-willturner
91bfb0d182 Merge https://github.com/ptillet/triton 2020-05-07 13:29:44 +01:00
jack-willturner
be02315168 [DOCS] Matrix copy and transpose 2020-05-05 14:30:49 +01:00
Philippe Tillet
2dd3fcb270 [CORE] Fixed bug for Multi-GPU 2020-05-04 18:36:44 -04:00
jack-willturner
f5d47536c5 [DOCS] Matmul and vecadd working examples 2020-05-04 16:25:17 +01:00
52 changed files with 1487 additions and 1086 deletions

View File

@@ -21,7 +21,7 @@ endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++11")
# Tests
if(BUILD_TESTS)
@@ -33,8 +33,10 @@ endif()
if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# PyBind11 wrapper source file
set(PYTHON_SRC bindings.cc)
set(PYTHON_SRC bindings.cc launch.cc)
set_source_files_properties(launch.cc PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0")
include_directories("." ${PYTHON_INCLUDE_DIRS})
link_directories(${PYTHON_LINK_DIRS})
endif()
@@ -43,3 +45,6 @@ file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton ${LLVM_LIBRARIES})
if(BUILD_PYTHON_MODULE)
target_link_libraries(triton ${TORCH_LIBRARIES})
endif()

View File

@@ -1,29 +1,41 @@
# Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with much higher flexibility than [TVM](https://github.com/apache/incubator-tvm) and without having to manually specify compute schedules.
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
The main components of Triton at the moment are:
- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level (think numpy-like array operations in a C-like language).
- **Triton-IR**: A special-purpose intermediate representation (Triton-IR) for aiding array-level program analysis and optimizations in Triton-C programs.
- **Triton-JIT**: An optimizing just-in-time compiler for Triton-IR, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)). This includes transparent support for mixed-precision and Tensor Cores.
Bindings for **automatic** PyTorch custom op generations are included in **PyTriton**, along with a small DSL based on einsum that supports convolutions, shift-convolutions, direct einsums, etc.
The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please cite us if you use our work!
The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!
## Installation
Triton is a fairly self-contained package and uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and LLVM-8.0+ for code generation.
You can install the latest release with pip as follows:
```
sudo apt-get install llvm-9-dev
pip install triton
```
or the latest development version with:
```
sudo apt-get install llvm-8-dev
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
```
for the C++ package:
```
git clone https://github.com/ptillet/triton.git;
mkdir build;
cd build;
cmake ../;
make -j8;
```
## Getting Started
@@ -32,8 +44,4 @@ Please visit the [documentation](https://docs.triton-lang.org) to get started wi
## Contributing
Please keep in mind that this is a project I have been carrying out completely on my own as part of my Ph.D. thesis. While I am confident in the approach, there are still many things to fix and to polish. Please contact me (ptillet AT g.harvard.edu) or raise an issue if you want to contribute!
## ISAAC (deprecated) for fast inference
Before working on Triton, I wrote custom auto-tuned PTX code for fast, quantized inference on GPUs. While this project is now deprecated, you can use it at your own risk by checking out the "isaac" tag in this repository.
Please keep in mind that this is a project I have been carrying out completely on my own as part of my Ph.D. thesis. While I am confident in the approach, there are still many things to fix and to polish. Please contact me (ptillet AT g.harvard.edu) or raise an issue if you want to contribute!

View File

@@ -1,6 +1,15 @@
Installation
============
Triton can be installed directly from pip with the following command
.. code-block:: python
pip install triton
See the information below for more detailed information on custom builds.
.. toctree::
:maxdepth: 1

View File

@@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just
"""
# create callable kernel for the source-code
# options: 4 warps and a -DTILE=1024
kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
# Forward pass
@staticmethod
@@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just
zb = add(x, y)
diff = (za - zb).abs().max()
print(diff)
print(torch.allclose(za,zb))
Executing the above code will:
@@ -97,3 +98,5 @@ Executing the above code will:
- Call the resulting custom op
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.

View File

@@ -10,7 +10,7 @@ The purpose of this section is to present a Triton-C implementation of matrix mu
Compute Kernel
==============
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
.. code-block:: C
@@ -24,7 +24,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
int rm[TM] = pm * TM + 0 ... TM; //(3)
int rn[TN] = pn * TN + 0 ... TN; //(4)
int rk[TK] = 0 ... TK; //(5)
// initialize accumulator
// initialize accumulator
float c[TM, TN] = 0; //(6)
// pointers to operands
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
@@ -32,10 +32,10 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
// reduction loop
for(int k = K; k > 0; k-= TK){
// fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
c += dot(a, b); //(11)
c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
@@ -85,10 +85,10 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with
.. code-block:: C
// pre-fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
c += dot(a, b);
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
@@ -113,7 +113,7 @@ Rematerialization
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
*pc = c;
------------------------------------
@@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += dot(USE_A, USE_B);
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
@@ -163,16 +163,16 @@ All matrix multiplications variants can then be retrieved using the following co
.. code-block:: C
// A is not transposed
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
-DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
// A is transposed
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
-DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
// B is not transpose
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
// B is transpose
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
@@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros:
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.

View File

@@ -3,7 +3,7 @@ Matrix Transpositions
*********************
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity interferes with auto-tuning.
@@ -16,7 +16,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
.. code-block:: C
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
__global__ void transpose(TYPE * X, TYPE * Y,
__global__ void transpose(TYPE * X, TYPE * Y,
int M, int N, int ldx, int ldy) {
// extract program ID
int pidm = get_program_id(0); //(1)
@@ -25,12 +25,12 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}
At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
@@ -54,7 +54,7 @@ which will be used in statements (5) and (6) to construct tiles of pointers
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
::
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
@@ -73,6 +73,65 @@ which will be used in statements (5) and (6) to construct tiles of pointers
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
==================================
A Note on Numpy-style Broadcasting
==================================
The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.
First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix.
.. code-block:: python
import numpy as np
vec = np.linspace(0,11,12)
mat = vec.reshape((4,3))
Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom).
::
[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]
Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like for tile 0 to have the values:
::
[ 0, 1, 2],
[ 3, 4, 5],
We construct ranges `rm` and `rn` as:
::
rm = [0, 1]
rn = [0, 1, 2]
Using numpy-style broadcasting, we can add these together to create a matrix:
::
rm[:, np.newaxis] + rn[np.newaxis, :]
rn -> [0, 1, 2]
rm -> [0., [[0, 1, 2],
1.] [1, 2, 3]]
The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index
of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need
to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension).
In this case this is 3, so what we really need is:
::
ldx = 3
px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:]
`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example).
==========================
The __multipleof attribute
==========================
@@ -83,19 +142,19 @@ For this reason, Triton provides a __multipleof(N) attributes for variables that
.. code-block:: C
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
int ldx __multipleof(8),
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
int ldx __multipleof(8),
int ldy __multipleof(8)) {
// ...
}
==========================
Bounds Checking
==========================
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
.. code-block:: C
@@ -108,6 +167,8 @@ You might have noticed that the above code will fail when `M` and `N` are not mu
// conditional write-back using the conditional dereferencing operatior '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.

View File

@@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code
'TYPE' : dtype,
'AT' : transpose_a,
'BT' : transpose_b,
'TM' : [32, 64, 128]
'TN' : [32, 64, 128]
'TK' : [8]
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [8],
# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
@@ -177,4 +177,4 @@ A callable operation can be created using the :code:`apply` method of the :code:
dot = _dot.apply
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~

View File

@@ -11,6 +11,7 @@ namespace ir {
class module;
class phi_node;
class splat_inst;
class cast_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
@@ -44,6 +45,7 @@ private:
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
@@ -53,6 +55,7 @@ private:
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps

View File

@@ -19,7 +19,7 @@ namespace transform{
class peephole {
private:
bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);

View File

@@ -32,7 +32,7 @@ public:
driver::context* context() const;
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL) = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **extra = NULL) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
@@ -53,7 +53,7 @@ public:
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -66,7 +66,7 @@ public:
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -80,7 +80,7 @@ public:
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};

View File

@@ -39,7 +39,9 @@ enum attribute_kind_t {
writeonly,
noalias,
aligned,
multiple_of
multiple_of,
retune,
not_implemented
};
class attribute {
@@ -112,6 +114,7 @@ public:
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
// visitor

View File

@@ -20,15 +20,18 @@ class visitor;
//===----------------------------------------------------------------------===//
class value {
public:
typedef std::set<user*> users_t;
public:
// constructor
value(type *ty, const std::string &name = "");
virtual ~value(){ }
// uses
void add_use(user* arg);
unsigned erase_use(user* arg);
users_t::iterator erase_use(user* arg);
const std::set<user*> &get_users() { return users_; }
virtual void replace_all_uses_with(value *target);
void replace_all_uses_with(value *target);
// name
void set_name(const std::string &name);
const std::string &get_name() const { return name_; }
@@ -41,7 +44,7 @@ private:
protected:
type *ty_;
std::set<user*> users_;
users_t users_;
};
//===----------------------------------------------------------------------===//
@@ -63,6 +66,7 @@ public:
user(type *ty, unsigned num_ops, const std::string &name = "")
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
}
virtual ~user() { }
// Operands
const ops_t& ops() { return ops_; }
@@ -74,8 +78,7 @@ public:
unsigned get_num_hidden() const;
// Utils
void replace_all_uses_with(value *target);
void replace_uses_of_with(value *before, value *after);
value::users_t::iterator replace_uses_of_with(value *before, value *after);
private:

View File

@@ -64,7 +64,8 @@ public:
ALIGNED,
NOALIAS,
READONLY,
WRITEONLY
WRITEONLY,
RETUNE,
};
KindT kind;

View File

@@ -53,13 +53,14 @@ inline bool is_int_type(arg_type ty){
}
class arg {
private:
public:
union value_t {
bool int1;
int8_t int8;
int16_t int16;
int32_t int32;
int64_t int64;
uint16_t fp16;
float fp32;
double fp64;
driver::buffer* buf;
@@ -67,6 +68,7 @@ private:
public:
// construct from primitive types
arg(arg_type ty, value_t val): ty_(ty) { val_ = val; }
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }

View File

@@ -3,7 +3,7 @@
#ifndef _TRITON_RUNTIME_FUNCTION_H_
#define _TRITON_RUNTIME_FUNCTION_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
@@ -62,6 +62,7 @@ public:
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<int> num_warps;
std::vector<int> recompile_key;
};
struct options_t {
@@ -94,19 +95,25 @@ private:
// accessors
const options_t opt() const { return opt_; }
const driver::module* parent() const { return &*parent_; }
const driver::kernel* bin() const { return &*bin_; }
arg_type param_ty(size_t i) const { return param_tys_.at(i);}
const std::vector<arg_type>& param_tys() const { return param_tys_; }
std::vector<int> retune() const { return retune_; }
// entry points
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;
private:
std::shared_ptr<driver::kernel> bin_;
std::shared_ptr<driver::module> parent_;
std::vector<arg_type> param_tys_;
std::vector<int> retune_;
options_t opt_;
std::string name_;
};
private:
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
typedef std::pair<driver::device*, std::vector<int32_t>> cache_key_t;
private:
// cache
@@ -118,16 +125,15 @@ private:
caller *make(driver::stream *stream, options_t opt);
void precompile(driver::stream *stream, const options_space_t& tuning_space);
// autotune
function::cache_key_t get_key(driver::stream *stream, const std::vector<arg>& args);
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
public:
static std::string preheader();
public:
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void set_cst(const std::string& name, void* data, size_t n_bytes);
private:
@@ -138,6 +144,8 @@ private:
options_space_t opt_;
std::set<options_t> compiled_;
std::map<options_t, std::unique_ptr<caller>> callers_;
std::vector<int> args_off_;
size_t args_size_;
// caching
std::string cache_ref_;
std::string cache_path_;

View File

@@ -38,19 +38,22 @@ inline double bench(std::function<void()> const & op, driver::stream * stream, b
double total_time = 0;
op();
stream->synchronize();
while(total_time*1e-9 < 1e-2){
float norm = 1;
tmr.start();
for(size_t i = 0; i < 10; i++){
// while(total_time*1e-9 < 1e-2){
// float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(normalize)
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
// times.push_back(norm*tmr.get().count());
// total_time+=times.back();
}
return *std::min_element(times.begin(), times.end());
stream->synchronize();
return (float)tmr.get().count() / 10;
// return *std::min_element(times.begin(), times.end());
}
}

View File

@@ -287,10 +287,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst > 0)
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
if(rhs_cst_info[d].num_cst > 0)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
@@ -329,9 +327,16 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
auto result = populate_max_contiguous(v->get_operand(0));
return add_to_cache(v, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
@@ -436,16 +441,16 @@ std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
}
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
@@ -466,6 +471,13 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
@@ -510,6 +522,11 @@ void align::populate(ir::value *v) {
void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
}

View File

@@ -22,6 +22,30 @@ namespace codegen{
using namespace llvm;
// Function for extended Euclidean Algorithm
inline int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
inline int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
using llop = llvm::Instruction::BinaryOps;
@@ -153,6 +177,7 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly);
case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly);
case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value());
case ir::retune: return llvm::Attribute::get(ctx, llvm::Attribute::None);
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
}
}
@@ -309,7 +334,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
unsigned contiguous = 1;
if(ld < x->get_type()->get_tile_rank())
contiguous = result->axis(ld).contiguous;
unsigned vector_size = std::min<unsigned>(contiguous, alignment);
unsigned vector_size = gcd(contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
@@ -329,7 +354,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
unsigned contiguous = 1;
if(ld < x->get_type()->get_tile_rank())
contiguous = result->axis(ld).contiguous;
unsigned vector_size = std::min<unsigned>(contiguous, alignment);
unsigned vector_size = gcd(contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
@@ -347,7 +372,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
std::map<unsigned, Value*> packets;
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned vector_size = gcd(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
@@ -400,7 +425,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// extract result element
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
unsigned vector_size = gcd(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
@@ -418,6 +443,8 @@ void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) {
});
}
void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand());
distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand());
@@ -425,7 +452,7 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
int vector_size = 1;
int ld = ptrs->get_order()[0];
unsigned alignment = alignment_->get(st->get_pointer_operand(), ld);
vector_size = std::min<unsigned>(ptrs->axis(ld).contiguous, alignment);
vector_size = gcd(ptrs->axis(ld).contiguous, alignment);
// create packets
std::map<unsigned, Value*> packets;
ir::value *arg = st->get_value_operand();
@@ -556,14 +583,14 @@ void generator::visit_exp_inst(ir::exp_inst* x){
// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false);
std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
for_each(x, [&](indices_t idx){
Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e);
set_value(x, idx, builder_->CreateCall(ex2, {ex2arg}));
set_value(x, idx, builder_->CreateCall(ex2, std::vector<llvm::Value*>{ex2arg}));
});
}
@@ -584,7 +611,7 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val,
AtomicOrdering::Monotonic,
AtomicOrdering::Monotonic);
old = builder_->CreateExtractValue(old, {0});
old = builder_->CreateExtractValue(old, std::vector<unsigned>{0});
Value *atom_ptr;
atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))));
atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3));
@@ -618,8 +645,25 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
tgt_->add_memfence(module, *builder_);
}
void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
throw std::runtime_error("unsupported");
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = vmap_.at(add->get_operand(0));
Value *rmw_val = vmap_.at(add->get_operand(1));
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
tgt_->add_memfence(module, *builder_);
tgt_->add_barrier(module, *builder_);
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
AtomicOrdering::Monotonic,
SyncScope::System);
builder_->CreateBr(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
}
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -640,8 +684,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
Type *fp32_ty = builder_->getFloatTy();
Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Type *fp32_pack8_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
@@ -720,15 +764,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder_->CreateExtractValue(nc, {0});
fc[idx[1]] = builder_->CreateExtractValue(nc, {1});
fc[idx[2]] = builder_->CreateExtractValue(nc, {2});
fc[idx[3]] = builder_->CreateExtractValue(nc, {3});
fc[idx[4]] = builder_->CreateExtractValue(nc, {4});
fc[idx[5]] = builder_->CreateExtractValue(nc, {5});
fc[idx[6]] = builder_->CreateExtractValue(nc, {6});
fc[idx[7]] = builder_->CreateExtractValue(nc, {7});
Value *nc = builder_->CreateCall(mma_fn, std::vector<llvm::Value*>{ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{0});
fc[idx[1]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{1});
fc[idx[2]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{2});
fc[idx[3]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{3});
fc[idx[4]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{4});
fc[idx[5]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{5});
fc[idx[6]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{6});
fc[idx[7]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{7});
}
}
}
@@ -770,7 +814,7 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_ti
a = builder_->CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder_->CreateFPCast(b, c_ty);
res = builder_->CreateCall(f_mul_add, {a, b, res});
res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
}
set_value(dot, idx, res);
});
@@ -790,7 +834,7 @@ void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distrib
a = builder_->CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder_->CreateFPCast(b, c_ty);
res = builder_->CreateCall(f_mul_add, {a, b, res});
res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
set_value(dot, idx, res);
});
}
@@ -805,7 +849,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
@@ -835,8 +879,8 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
for_each(sqt, [&](indices_t idx){
Value *val = get_value(sqt->get_operand(0), idx);
Module* module = builder_->GetInsertBlock()->getModule();
Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, {val->getType()});
Value *ret = builder_->CreateCall(sqrt, {val});
Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, std::vector<llvm::Type*>{val->getType()});
Value *ret = builder_->CreateCall(sqrt, std::vector<llvm::Value*>{val});
set_value(sqt, idx, ret);
});
}
@@ -849,7 +893,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
unsigned axis = x->get_axis();
Type *fp32_ty = builder_->getFloatTy();
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false);
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty}, false);
InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false);
InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false);
@@ -871,8 +915,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
}
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y});
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y});
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, std::vector<llvm::Value*>{x, y});
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, std::vector<llvm::Value*>{x, y});
default: assert(false); return nullptr;
}
};
@@ -910,11 +954,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
thread_acc = accumulate(thread_acc, current);
});
// reduce within wrap
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false);
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), std::vector<llvm::Type*>{thread_acc->getType(), builder_->getInt32Ty()}, false);
InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
Value *warp_acc = thread_acc;
for(int i = 16; i > 0; i >>= 1)
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)}));
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{warp_acc, builder_->getInt32(i)}));
// shared memory pointer
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
Type *res_ty = arg_tile->get_ty();
@@ -935,7 +979,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
builder_->SetInsertPoint(bb_final_acc);
Value* final_val = builder_->CreateLoad(load_ptr);
for(int i = (num_warps_+1)/2; i > 0; i >>= 1)
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)}));
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{final_val, builder_->getInt32(i)}));
builder_->CreateStore(final_val, load_ptr);
builder_->CreateBr(bb_final_acc_done);
// // store first warp done

View File

@@ -109,14 +109,18 @@ void coalesce::run(ir::module &mod) {
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){
extract_ld(i, axes);
}
}
// update list of values to rematerialize
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
for(auto it = ++axes.rbegin(); it != axes.rend(); it++){
if(it->second.size() == 1)
continue;
remat.insert(remat.begin(), it->second.begin(), it->second.end());
}
}
// rematerialize values
for(ir::io_inst *r: remat) {

View File

@@ -34,10 +34,13 @@ void disassociate::run(ir::module &mod) {
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
ir::value* op = i->get_operand(0);
if(!dynamic_cast<ir::user*>(op))
return;
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
return;
std::map<int, std::set<ir::user*>> chains;
std::set<ir::value*> seen;
if(!dynamic_cast<ir::user*>(i->get_operand(0)))
return;
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
clone_info[i] = chains;

View File

@@ -81,20 +81,21 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
add->replace_all_uses_with(new_dot);
return true;
}
return false;
}
bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
if(cfs) {
ir::value *arg = cfs->get_operand(0);
ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
if(!cts)
return false;
cfs->replace_all_uses_with(cts->get_operand(0));
return true;
}
//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
// if(cfs) {
// ir::value *arg = cfs->get_operand(0);
// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
// if(!cts)
// return false;
// cfs->replace_all_uses_with(cts->get_operand(0));
// return true;
// }
}
//}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
@@ -196,7 +197,7 @@ void peephole::run(ir::module &mod) {
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);

View File

@@ -146,7 +146,8 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
llvm::Function* fn = src->getFunction("matmul");
llvm::FunctionType *fn_ty = fn->getFunctionType();
@@ -250,11 +251,10 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
int version;
dispatch::cuDriverGetVersion(&version);
int major = version / 1000;
int minor = (version - major*1000) / 10;
// int minor = (version - major*1000) / 10;
if(major < 10)
throw std::runtime_error("Triton requires CUDA 10+");
if(minor >= 1)
find_and_replace(result, ".version", "\n", ".version 6.4\n");
find_and_replace(result, ".version", "\n", ".version 6.4\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
return result;
@@ -265,6 +265,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx(*context);
// std::cout << source << std::endl;
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;

View File

@@ -79,7 +79,7 @@ void host_stream::synchronize() {
}
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
@@ -112,7 +112,7 @@ void cl_stream::synchronize() {
check(dispatch::clFinish(*cl_));
}
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
}
@@ -149,12 +149,11 @@ void cu_stream::synchronize() {
dispatch::cuStreamSynchronize(*cu_);
}
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
driver::cu_kernel* cu_kernel = (driver::cu_kernel*)kernel;
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** extra) {
cu_context::context_switcher ctx_switch(*ctx_);
if(event)
dispatch::cuEventRecord(event->cu()->first, *cu_);
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_,(void**)cu_kernel->cu_params(), NULL);
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra);
if(event)
dispatch::cuEventRecord(event->cu()->second, *cu_);
}

View File

@@ -70,9 +70,10 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
return phi;
// unique value or self-reference
ir::value *same = *non_self_ref.begin();
std::set<ir::user*> users = phi->get_users();
assert(same != nullptr);
phi->replace_all_uses_with(same);
phi->erase_from_parent();
std::set<ir::user*> users = phi->get_users();
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)

View File

@@ -1,4 +1,5 @@
#include <cassert>
#include <iostream>
#include "triton/ir/value.h"
#include "triton/ir/instructions.h"
@@ -19,8 +20,11 @@ void value::add_use(user *arg) {
users_.insert(arg);
}
unsigned value::erase_use(user *arg){
return users_.erase(arg);
value::users_t::iterator value::erase_use(user *arg){
auto it = users_.find(arg);
if(it == users_.end())
return it;
return users_.erase(it);
}
// TODO: automatic naming scheme + update symbol table
@@ -29,9 +33,12 @@ void value::set_name(const std::string &name){
}
void value::replace_all_uses_with(value *target){
throw std::runtime_error("not implemented");
for (auto it = users_.begin(); it != users_.end(); ) {
it = (*it)->replace_uses_of_with(this, target);
}
}
void visitor::visit_value(ir::value* v) {
v->accept(this);
}
@@ -59,18 +66,12 @@ unsigned user::get_num_hidden() const {
return num_hidden_;
}
void user::replace_all_uses_with(value *target) {
for(auto it = users_.begin(); it != users_.end(); it++){
(*it)->replace_uses_of_with(this, target);
}
}
void user::replace_uses_of_with(value *before, value *after) {
value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
for(size_t i = 0; i < ops_.size(); i++)
if(ops_[i] == before)
ops_[i] = after;
after->add_use(this);
before->erase_use(this);
return before->erase_use(this);
}

View File

@@ -749,10 +749,21 @@ void UnaryOp::CastOpTypeChecking() {
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
Error(this, "tile with more than one element cannot be casted to scalar");
if(type_->IsTile() && operandType->IsTile()){
auto shape = type_->ToTile()->Shape();
auto operandShape = operandType->ToTile()->Shape();
if(operandShape.size() > shape.size())
Error(this, "cast cannot reduce operand rank");
auto shape = type_->ToTile()->Shape();
// this is a shape downcast
if(operandShape.size() > shape.size()){
size_t operandNumel = 1;
size_t numel = 1;
for(auto x: operandShape)
operandNumel *= x;
for(auto x: shape)
numel *= x;
if(operandNumel != numel)
Error(this, "cast cannot change the number of elements");
return;
}
// this is a shape upcast
while(operandShape.size() < shape.size())
operandShape.insert(operandShape.begin(), 1);
for(size_t i = 0; i < shape.size(); i++) {
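
The new branch legalizes rank-reducing casts only when the total element count is unchanged. A standalone sketch of the same rule (the function name is illustrative):

from functools import reduce
from operator import mul

def downcast_ok(operand_shape, shape):
    # rank may shrink only if the total number of elements is preserved:
    # [2, 8] -> [16] is a legal reshape-style cast, [2, 8] -> [8] is not
    numel = lambda s: reduce(mul, s, 1)
    return numel(operand_shape) == numel(shape)

assert downcast_ok([2, 8], [16])
assert not downcast_ok([2, 8], [8])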

View File

@@ -277,6 +277,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
ir::value* val = ret_;
return set_ret(bld_->create_atomic_exch(ptr, val));
}
if(name == "f32_atomic_add"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_add(ptr, val));
}
if(name == "sqrtf"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
@@ -623,7 +630,10 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
return ir::attribute(ir::readonly);
if(attr.kind == ASTNode::Attr::WRITEONLY)
return ir::attribute(ir::writeonly);
if(attr.kind == ASTNode::Attr::RETUNE)
return ir::attribute(ir::retune);
error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
return ir::attribute(ir::not_implemented);
}
void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {
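
With f32_atomic_add wired through the code generator here and declared in the preheader (see the runtime diff below), kernels can accumulate fp32 values across program instances. A deliberately tiny, hypothetical Triton-C kernel, embedded the way the examples later in this changeset embed theirs:

src = """
__global__ void bump(float* counter) {
    // every program instance atomically adds 1.0 to the same scalar
    f32_atomic_add(counter, 1.0);
}
"""
# launched over a grid of N programs, *counter ends up incremented by N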

View File

@@ -2778,6 +2778,8 @@ ASTNode::Attr Parser::ParseAttribute() {
ret.kind = ASTNode::Attr::MULTIPLEOF;
else if(name == "noalias")
ret.kind = ASTNode::Attr::NOALIAS;
else if(name == "retune")
ret.kind = ASTNode::Attr::RETUNE;
else
Error(tok, "unknown attribute kind");
// set exprs

View File

@@ -151,27 +151,23 @@ function::caller::caller(ir::function *ir,
bin_.reset(driver::kernel::create(&*parent, name_.c_str()));
// extract signature
ir::function_type* ty = ir->get_fn_type();
for(size_t i = 0; i < ty->get_num_params(); i++)
for(size_t i = 0; i < ty->get_num_params(); i++){
param_tys_.push_back(convert(ty->get_param_ty(i)));
if(!ir->has_attr(i+1))
continue;
for(ir::attribute attr: ir->attrs().at(i + 1))
if(attr.get_kind() == ir::retune)
retune_.push_back(i);
}
}
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, const std::vector<arg>& args) const {
if(args.size() != param_tys_.size())
throw std::runtime_error("invalid number of arguments");
// set arguments
for(size_t i = 0; i < args.size(); i++){
arg arg_i = args.at(i);
arg_type ty = arg_i.type();
if(ty != param_tys_.at(i))
throw std::runtime_error("invalid type for argument " + std::to_string(i));
if(ty == BUFFER_T){
driver::buffer* buf = *((driver::buffer**)arg_i.data());
bin_->setArg(i, buf->size() == 0 ? nullptr : buf);
}
else
bin_->setArg(i, size_of(ty), arg_i.data());
}
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const {
void *config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END
};
// set grid
if(_grid.size() > 3)
throw std::runtime_error("grid size must be no greater than 3");
@@ -179,7 +175,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1});
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, config);
}
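
The rewritten operator() no longer marshals arguments one by one: CUDA receives a single contiguous blob through CU_LAUNCH_PARAM_BUFFER_POINTER plus its size through CU_LAUNCH_PARAM_BUFFER_SIZE. Python's struct module produces exactly such natively aligned blobs, which is what the new launch path counts on. A sketch assuming a hypothetical signature kernel(float* A, float* B, float alpha, int M):

import struct

# 'P' = pointer, 'f' = float, 'i' = int32; native ('@') mode reproduces the
# C struct layout the kernel reads its arguments from
blob = struct.pack('PPfi', 0x1000, 0x2000, 1.0, 64)  # fake device addresses
print(struct.calcsize('PPfi'))                       # 24 bytes: 8 + 8 + 4 + 4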
@@ -218,7 +214,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
codegen::transform::cts cts;
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes
// ir::print(module, std::cout);
dce.run(module);
disassociate.run(module);
dce.run(module);
@@ -252,20 +247,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
// create Binary from options
function::caller* function::make(driver::stream *stream, options_t opt) {
// cache path
std::string cache_path = cache_path_ + opt.to_str() + ".ptx";
int ref_mtime = tools::mtime(cache_ref_);
int ptx_mtime = tools::mtime(cache_path);
// if cached ptx is newer than reference library
if(!ref_mtime || !ptx_mtime || ref_mtime < ptx_mtime){
std::ifstream ifs(cache_path);
// file is empty -- invalid
if(ifs && ifs.peek() == std::ifstream::traits_type::eof())
return nullptr;
// load cached caller
if(ifs)
return new caller(stream->context(), ifs, opt);
}
// pre-process
TokenSequence tokens;
Preprocessor cpp(&src_, true);
@@ -282,18 +263,11 @@ function::caller* function::make(driver::stream *stream, options_t opt) {
try{
bin = make_bin(*ir, stream->context(), opt);
}catch(const std::runtime_error&){
if(!cache_path_.empty())
std::ofstream ofs(cache_path);
return nullptr;
}
// create callable
ir::function *tmp = ir->get_function_list()[0];
caller* ret = new caller(tmp, std::move(bin), opt);
// serialize callable
if(!cache_path_.empty()){
std::ofstream ofs(cache_path);
ret->write(ofs);
}
return ret;
}
@@ -331,49 +305,25 @@ void function::precompile(driver::stream* stream,
throw std::runtime_error("could not compile kernel");
}
// return auto-tuning key for given function arguments
function::cache_key_t function::get_key(driver::stream *stream, const std::vector<arg>& args) {
cache_key_t ret;
ret.first = stream->context()->device();
for(size_t i = 0; i < args.size(); i++){
arg_type ty = args.at(i).type();
if(!is_int_type(ty))
continue;
long val = 0;
std::memcpy((void*)&val, args.at(i).data(), size_of(ty));
ret.second.push_back(val);
}
return ret;
}
// returns program with best compilation options for given parameter
function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
const std::vector<arg>& args) {
// // copy buffer argument so that auto-tuning doesn't corrupt data
// std::list<std::shared_ptr<driver::cu_buffer>> copies;
// std::vector<arg> args = args;
// for(arg& x: args)
// if(x.type() == BUFFER_T){
// driver::buffer* old = x.buffer();
// driver::context* ctx = old->context();
// size_t size = old->size();
// copies.push_back(std::make_shared<driver::cu_buffer>(ctx, size));
// x = arg(copies.back().get());
// }
void** args, size_t args_size) {
// fast path -- no autotuning necessary
if(callers_.size() == 1)
return &*callers_.begin()->second;
// slow path -- autotuning necessary
// TODO: copy buffer argument so that auto-tuning doesn't corrupt data
double best_ts = INFINITY;
caller* ret = nullptr;
for(auto &x : callers_){
if(x.second == nullptr)
throw std::runtime_error("configuration not compiled");
caller* current = &*x.second;
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args); },
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size); },
stream, true);
ret = (ts < best_ts) ? current : ret;
best_ts = std::min(ts, best_ts);
}
stream->synchronize();
return ret;
}
@@ -394,12 +344,14 @@ std::string function::preheader() {
#define __noalias __attribute__((noalias))
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
#define __retune __attribute__((retune))
#define F32_INFINITY bitcast<float>(0x7F800000)
#define F16_INFINITY bitcast<half>((int16)0x7C00)
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern float f32_atomic_add(float*, float);
extern int get_program_id(int);
extern int get_num_programs(int);
extern float sqrtf(float);
@@ -452,27 +404,35 @@ function::function(const std::string &src,
src_ = preheader() + src_;
}
void function::operator()(const std::vector<arg>& args,
const grid_fn_ty& grid_fn,
driver::stream *stream) {
void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
// pre-compile kernels
if(callers_.empty())
if(callers_.empty()){
precompile(stream, opt_);
size_t cumsum = 0;
for(arg_type ty: callers_.begin()->second->param_tys()){
args_off_.push_back(cumsum);
cumsum += size_of(ty);
}
}
// re-tuning key
cache_key_t key;
key.first = stream->context()->device();
key.second = callers_.begin()->second->retune();
// auto-tune if necessary
auto key = get_key(stream, args);
auto it = cache_.find(key);
if(it == cache_.end()){
auto best = autotune(stream, grid_fn, args);
auto best = autotune(stream, grid_fn, args, args_size);
it = cache_.insert({key, best}).first;
}
// run
(*it->second)(stream, grid_fn(it->second->opt()), args);
(*it->second)(stream, grid_fn(it->second->opt()), args, args_size);
}
void function::operator()(const std::vector<arg>& args,
void function::operator()(void** args,
size_t args_size,
const grid_t& grid,
driver::stream *stream) {
return this->operator()(args, [&grid](const options_t&){ return grid; }, stream);
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
}

View File

@@ -0,0 +1,202 @@
import torch
import triton
class _conv(torch.autograd.Function):
src = """
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M __retune,
int N __retune,
int K __retune,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs [TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr [TK] = rcir % RR;
int rci [TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz [:, newaxis] * lda_z +
rci[newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE* pa[TM, TK] = A + offa;
int* padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr [:, newaxis] * ldb_r +
rs [:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE* pb[TK, TN] = B + offb;
// prefetch operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta[newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *?(checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co+
rp [:, newaxis] * ldc_p +
rq [:, newaxis] * 1;
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
"""
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride, time):
# create kernel if necessary
dtype = a.dtype
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2*pad[0] - R)//stride[0] + 1
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if dtype not in _conv.kernel:
TK = 8
defines = {
'TYPE' : dtype,
'TM' : [16, 32, 64, 128],
'TN' : [16, 32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = triton.empty([Z, CO, P, Q], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
triton.cdiv(CO, opt.d('TN'))]
time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
pad[0], pad[1], stride[0], stride[1],
delta,
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
grid=grid, bench=100)
return c
conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
b = torch.rand((CI, R, S, CO)).cuda()
time = [None]
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
c = conv(a, b, pad, stride, time)
print((cc - c).abs().max() / max(cc.max(), c.max()))
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
#zc = torch.matmul(a,b)
#zc_ = dot(a,b)
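
The ADELTA table built in forward() is the pointer-increment trick: rather than re-deriving (ci, r, s) offsets inside the reduction loop, the kernel loads a precomputed delta that advances the flattened index by TK. A standalone check with small, assumed sizes:

import torch

CI, R, S, TK = 2, 3, 3, 4
sci, sh, sw = 64, 8, 1      # assumed strides of a contiguous (CI, 8, 8) input
idx = torch.arange(CI * R * S)
s, cr = idx % S, idx // S
r, ci = cr % R, cr // R
n = idx + TK
ns, ncr = n % S, n // S
nr, nci = ncr % R, ncr // R
delta = (nci - ci) * sci + (nr - r) * sh + (ns - s) * sw
# delta[i] is what a pointer sitting at flattened index i must be advanced by
# to land on index i + TK; the kernel reads these via `int adelta[TK] = *padelta`
print(delta[:4])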

View File

@@ -0,0 +1,70 @@
import torch
import triton
class _copy(torch.autograd.Function):
src = """
__global__ void copy(TYPE * X, TYPE * Y,
int M __retune,
int N __retune,
int ldx __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D ranges along the matrix's two axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)
*py = *px;
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N
dtype = x.dtype
y = torch.empty((M,N)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _copy.kernel is None:
_copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
_copy.kernel(x, y, M, N, ldx, grid=grid)
return y
copy = _copy.apply
# test
torch.manual_seed(0)
x = torch.randn(8,4).cuda()
print(x)
ya = x
yb = copy(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)

View File

@@ -0,0 +1,133 @@
import torch
import triton
class _dot(torch.autograd.Function):
src = """
__global__ void dot(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune,
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// prefetch operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
a = *?(checka)pa;
b = *?(checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
"""
@staticmethod
def forward(ctx, a, b):
c = _dot._call(a,b)
return c
kernel = dict()
@staticmethod
def _call(a, b):
# create kernel if necessary
dtype = a.dtype
if dtype not in _dot.kernel:
defines = {
'TYPE' : dtype,
'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [64, 128],
'TN' : [64, 128],
'TK' : [8, 16],
'TZ' : [1]
}
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
kernel = _dot.kernel[dtype]
# allocate output
M, K = a.shape
K, N = b.shape
c = triton.empty([M,N], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0),
grid=grid, bench=100)
print(2*M*N*K/(time*1e-6)*1e-9)
return c
dot = _dot.apply
torch.manual_seed(0)
M, N, K = 2048, 2048, 2048
a = torch.rand((M, K)).cuda()
b = torch.rand((K, N)).cuda()
#zc = torch.matmul(a,b)
zc_ = dot(a,b)
#print(torch.allclose(zc, zc_))
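
The TZ > 1 epilogue in these kernels merges split-K partial results behind a spin-lock: atomic_cas takes the lock, a counter decides whether to overwrite or accumulate, and the last writer wraps the counter back to zero. The same protocol as a toy, single-process Python model:

import threading

TZ = 4
lock = threading.Lock()            # stands in for atomic_cas(plock, 0, 1)
count, out = 0, 0.0

def commit(partial):
    global count, out
    with lock:
        out = partial if count == 0 else out + partial
        count = (count + 1) % TZ   # atomic_xchg(pcount, (count + 1) % TZ)

for p in [1.0, 2.0, 3.0, 4.0]:
    commit(p)
print(out)  # 10.0, and count has wrapped back to 0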

View File

@@ -0,0 +1,76 @@
import torch
import triton
class _transpose(torch.autograd.Function):
src = """
__global__ void transpose(TYPE * X, TYPE * Y,
int M __retune,
int N __retune,
int ldx __multipleof(8), int ldy __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D ranges along the matrix's two axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// create bounds-checking mask
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
// conditional write-back using the conditional dereferencing operator '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N
ldy = M
dtype = x.dtype
y = torch.empty((N,M)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _transpose.kernel is None:
_transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
_transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
return y
transpose = _transpose.apply
# test
torch.manual_seed(0)
x = torch.randn(1024,128).cuda()
print(x)
ya = torch.t(x)
yb = transpose(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)

View File

@@ -0,0 +1,43 @@
import torch
import triton
class _add(torch.autograd.Function):
src = """
__global__ void add(float* z, float* x, float* y, int N) {
int pid = get_program_id(0);
int offset[TILE] = pid * TILE + 0 ... TILE;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
bool check[TILE] = offset < N;
*?(check)pz = *?(check)px + *?(check)py;
}
"""
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
@staticmethod
def forward(ctx, x, y):
z = torch.empty_like(x).cuda()
N = x.numel()
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
_add.kernel(z,x,y, N, grid=grid)
return z
add = _add.apply
# test
torch.manual_seed(0)
x = torch.rand(98432).cuda()
y = torch.rand(98432).cuda()
za = x + y
zb = add(x, y)
print(torch.allclose(za,zb))

View File

@@ -8,9 +8,11 @@ import distutils
import glob
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from torch.utils.cpp_extension import include_paths, library_paths
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn
import torch
def find_llvm():
@@ -58,12 +60,17 @@ class CMakeBuild(build_ext):
# python directories
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
torch_include_dirs = include_paths(True)
torch_library_dirs = library_paths(True)
abi = torch._C._GLIBCXX_USE_CXX11_ABI
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TESTS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
'-DLLVM_CONFIG=' + find_llvm()]
# configuration
cfg = 'Debug' if self.debug else 'Release'
@@ -80,8 +87,6 @@ class CMakeBuild(build_ext):
build_args += ['--', '-j4']
env = os.environ.copy()
env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
self.distribution.get_version())
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
@@ -100,7 +105,7 @@ for d in directories:
setup(
name='triton',
version='0.1',
version='0.2.3',
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',

View File

@@ -3,57 +3,51 @@
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
#include <regex>
#include <algorithm>
#include "triton/runtime/function.h"
#include "triton/runtime/arg.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/parser.h"
#include "triton/lang/cpp.h"
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/tools/bench.hpp"
using namespace triton;
namespace rt = triton::runtime;
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
typedef std::pair<size_t, size_t> map_key_t;
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::map<size_t, double> fp64scalar_map;
std::map<size_t, int64_t> i64scalar_map;
/* Grid map */
void register_grid(size_t id,
void register_grid(const map_key_t& key,
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
}
void delete_grid(size_t id) {
id_grid_map.erase(id);
void delete_grid(const map_key_t& key) {
id_grid_map.erase(key);
}
/* Function map */
void register_fn(size_t id,
void register_fn(const map_key_t& key,
const std::string& src,
const rt::function::options_space_t& opt,
const std::string &cache_ref) {
id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
}
void delete_fn(size_t id) {
id_fn_map.erase(id);
void delete_fn(const map_key_t& key) {
id_fn_map.erase(key);
}
void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
}
void cleanup() {
@@ -82,196 +76,6 @@ int64_t retrieve_scalar(size_t id) {
return i64scalar_map.at(id);
}
/* TF source-code generation */
inline std::string to_tf_ty(ir::type *ty) {
if(ty->is_integer_ty(1))
return "bool";
if(ty->is_integer_ty(8))
return "int8";
if(ty->is_integer_ty(16))
return "int16";
if(ty->is_integer_ty(32))
return "int32";
if(ty->is_integer_ty(64))
return "int64";
if(ty->is_half_ty())
return "float16";
if(ty->is_float_ty())
return "float";
if(ty->is_double_ty())
return "double";
if(ty->is_pointer_ty())
return "Tensor";
throw std::runtime_error("unknown type");
}
inline std::string to_tf_scalar_ty(ir::type *ty) {
if(ty->is_pointer_ty())
return to_tf_ty(ty->get_pointer_element_ty());
else {
return to_tf_ty(ty);
}
}
inline std::string ref_to_tf_ty(ir::type *ty) {
std::string res = to_tf_ty(ty);
if(ty->is_pointer_ty())
res = "const " + res + "&";
return res;
}
std::string tf_normalize(const std::string& name) {
std::string ret = name;
auto tolower = [](char c) { return std::tolower(c);};
std::transform(ret.begin(), ret.end(), ret.begin(), tolower);
return ret;
}
struct tf_alloc_t{
enum type_t{
OUTPUT,
TEMP
};
tf_alloc_t(const std::string& _name, type_t _type)
: name(_name), type(_type), tf_name(tf_normalize(_name)){ }
std::string tf_name;
std::string name;
type_t type;
size_t shape_id;
};
typedef std::vector<tf_alloc_t> alloc_map_t;
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
for(unsigned i = 0; i < args.size(); i++){
ir::value *arg = args[i];
const std::string& name = arg->get_name();
std::string ty = to_tf_ty(arg->get_type());
if(!arg->get_type()->is_pointer_ty())
os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
else if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return x.name == name;
}) == allocs.end())
os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
else
os << " Tensor* " << name << " = nullptr;\n ";
}
}
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
// initialize shapes
for(const auto& x: allocs)
os << " TensorShape " << x.name << "_shape;\n ";
for(const auto& x: allocs)
os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n ";
for(const auto& x: allocs)
os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n ";
for(const auto& x: allocs)
os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n ";
for(const auto& x: allocs)
os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) "
<< x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n ";
// allocate
int output = 0;
for(const auto& x: allocs){
if(x.type == tf_alloc_t::OUTPUT)
os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n ";
else
os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n ";
}
}
void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
for(unsigned i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
if(!arg->get_type()->is_pointer_ty())
continue;
const std::string& name = arg->get_name();
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->tensor_data().size(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n ";
}
}
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
if(arg->get_type()->is_pointer_ty())
name = "&cu_" + name;
if(i > 0)
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id_), stream);\n ";
os << " };\n ";
os << " run();\n ";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &opname,
const std::vector<ir::argument*>& args,
const alloc_map_t& allocs){
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = tf_normalize(arg->get_name());
if(!arg->get_type()->is_pointer_ty())
os << ".HostMemory(\"" + name + "\")";
}
for(const auto& x: allocs)
os << ".HostMemory(\"" << x.tf_name << "_shape\")";
os << ", " + opname << ");\n";
}
void gen_tf_register_op(std::ostream &os, const std::string &name,
const std::vector<ir::argument*>& args,
const alloc_map_t& allocs){
os << "REGISTER_OP(\"" << name << "\")\n";
for(size_t i = 0; i < args.size(); i++)
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = tf_normalize(arg->get_name());
if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return name == x.tf_name;
}) == allocs.end())
os << " .Input(\"" << name << ": T" << i << "\")\n";
else
os << " .Input(\"" << name << "_shape: int32\")\n";
}
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT)
os << " .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
size_t current = 0;
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT){
os << " shape_inference::ShapeHandle " << x.tf_name << "_handle;\n";
os << " ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n";
os << " ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n";
}
os << " return Status::OK();\n";
os << " })\n";
os << ";\n";
}
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
std::string copy = triton::runtime::function::preheader() + src;
@@ -289,318 +93,6 @@ void make_module(const std::string& src, ir::module* ir,
gen.Gen(ir);
}
std::tuple<std::string,
std::string> make_tensorflow_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt)
{
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
make_module(src, &*ir, opt);
// function
ir::function* fn = ir->get_function_list().front();
const std::vector<ir::argument*>& args = fn->args();
std::string name = fn->get_name();
std::string cc_name = name;
cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
std::string opname = cc_name + "Op";
// allocation info
alloc_map_t allocs;
for(size_t i = 0; i < outputs.size(); i++)
allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT));
for(size_t i = 0; i < tmp.size(); i++)
allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP));
for(auto &x: allocs){
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == x.name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
x.shape_id = idx;
}
std::ostringstream oss;
oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
class )" << opname << R"(: public OpKernel {
public:
explicit )" << opname << R"((OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
)";
for(const auto& alloc: allocs)
oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n ";
oss << R"(
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
drv::cu_stream sstream(device.stream(), false);
drv::context* ctx = sstream.context();
drv::stream* stream = &sstream;
// extract inputs
)";
gen_extract_inputs(oss, args, allocs);
oss << R"(
// set outputs
)";
gen_set_outputs(oss, args, allocs);
oss << R"(
// wrap tensors
)";
gen_make_handles(oss, args);
oss << R"(
)";
oss << R"(
// launch function
)";
gen_make_launch_function(oss, args);
oss << R"(
}
private:
int id_;
int bench_;
int64 bench_id_;
)";
for(const auto& alloc: allocs)
oss << "DataType " << alloc.name << "_type;\n ";
oss << R"(
};
// register kernel builder
)";
gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs);
oss << R"(
// register op
)";
gen_tf_register_op(oss, cc_name, args, allocs);
return std::tuple<std::string, std::string>{oss.str(), name};
}
inline std::string to_torch_ty(ir::type *ty) {
if(ty->is_integer_ty())
return "int64_t";
if(ty->is_half_ty())
return "double";
if(ty->is_float_ty())
return "double";
if(ty->is_double_ty())
return "double";
if(ty->is_pointer_ty())
return "torch::Tensor";
throw std::runtime_error("unknown type");
}
inline std::string to_torch_ty(rt::arg_type ty){
switch(ty){
case rt::INT1_T: return "int64_t";
case rt::INT8_T: return "int64_t";
case rt::INT16_T: return "int64_t";
case rt::INT32_T: return "int64_t";
case rt::INT64_T: return "int64_t";
case rt::HALF_T: return "double";
case rt::FLOAT_T: return "double";
case rt::DOUBLE_T: return "double";
case rt::BUFFER_T: return "torch::Tensor";
default: return "UNKNOWN";
}
}
inline std::string to_c_ty(rt::arg_type ty){
switch(ty){
case rt::INT1_T: return "bool";
case rt::INT8_T: return "int8_t";
case rt::INT16_T: return "int16_t";
case rt::INT32_T: return "int32_t";
case rt::INT64_T: return "int64_t";
case rt::HALF_T: return "half";
case rt::FLOAT_T: return "float";
case rt::DOUBLE_T: return "double";
case rt::BUFFER_T: return "drv::cu_buffer";
default: return "UNKNOWN";
}
}
inline std::string to_c_ty(ir::type *ty) {
if(ty->is_integer_ty(1))
return "bool";
if(ty->is_integer_ty(8))
return "int8_t";
if(ty->is_integer_ty(16))
return "int16_t";
if(ty->is_integer_ty(32))
return "int32_t";
if(ty->is_integer_ty(64))
return "int64_t";
if(ty->is_half_ty())
return "half";
if(ty->is_float_ty())
return "float";
if(ty->is_double_ty())
return "double";
if(ty->is_pointer_ty())
return "drv::cu_buffer";
throw std::runtime_error("unknown type");
}
void gen_torch_signature(std::ostringstream& oss,
const std::string& name,
const std::vector<rt::arg_type>& args) {
std::string ret_ty = "void";
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
oss << "int64_t bench_id, ";
for(size_t i = 0; i < args.size(); i++) {
if(i > 0)
oss << ", ";
oss << to_torch_ty(args[i]) << " " << "th_arg_" << i;
}
oss << ")";
}
void gen_torch_init_driver(std::ostringstream &oss,
const std::vector<rt::arg_type>&args) {
// Find index of first buffer
size_t i;
for(i = 0; i < args.size(); i++)
if(args[i] == rt::BUFFER_T)
break;
oss << " // Wrap CUDA handles" << std::endl;
oss << " c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl;
oss << " // Get stream" << std::endl;
oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
oss << " triton::driver::context* ctx = stream.context();" << std::endl;
}
void gen_torch_make_handles(std::ostream &os,
const std::vector<rt::arg_type>& args) {
for(unsigned i = 0; i < args.size(); i++){
rt::arg_type arg = args[i];
const std::string th_name = "th_arg_" + std::to_string(i);
const std::string name = "arg_" + std::to_string(i);
if(arg != rt::BUFFER_T)
os << " " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl;
else{
os << " CHECK_INPUT(" << th_name << ");" << std::endl;
os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), "
" (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl;
}
}
}
void gen_torch_make_launch_function(std::ostream &os,
const std::vector<rt::arg_type>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id))({";
for(unsigned i = 0; i < args.size() ; i++){
std::string name = "arg_" + std::to_string(i);
if(args[i] == rt::BUFFER_T)
name = "&" + name;
if(i > 0)
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id), &stream);\n";
os << " };\n";
os << " run();\n";
os << " if(bench > 0)\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
if(outputs.size() == 1){
os << " return " << outputs[0] << ";" << std::endl;
return;
}
os << " return {";
for(size_t i = 0; i < outputs.size(); i++){
if(i > 0)
os << ", ";
os << outputs[i];
}
os << "};" << std::endl;
}
std::tuple<std::string,
std::string> make_torch_src(const std::string& name, std::vector<rt::arg_type> args) {
// generate framework code
std::ostringstream oss;
oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x);
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
)";
gen_torch_signature(oss, name, args);
oss << " {" << std::endl;
gen_torch_init_driver(oss, args);
gen_torch_make_handles(oss, args);
gen_torch_make_launch_function(oss, args);
//gen_torch_ret(oss);
oss << "}" << std::endl;
oss << std::endl;
oss << std::endl;
oss << "static auto registry = torch::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
return std::tuple<std::string, std::string>{oss.str(), name};
}
/* Function signature */
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
const runtime::function::options_space_t& opt) {
@@ -624,13 +116,6 @@ typedef triton::runtime::function::options_space_t options_space_t;
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
// framework binding source code generation
m.def("make_tensorflow_src", &make_tensorflow_src,
"Creates C++ source code for a custom Tensorflow op "
"corresponding to the specified Triton kernel");
m.def("make_torch_src", &make_torch_src,
"Creates C++ source code for a custom PyTorch op ");
// bindings for triton classes
pybind11::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)

python/src/launch.cc (new file)
View File

@@ -0,0 +1,30 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x);
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<size_t, size_t> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
triton::driver::cu_stream stream(custream, false);
triton::driver::context* ctx = stream.context();
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
}
static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel);
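
This file is the receiving end of the struct.pack scheme credited above: Python packs every argument, tensor pointers included, into one byte string, and the op forwards it untouched to the runtime. A sketch of the calling side with placeholder ids (a real op_id and device index come from libtriton's registration, as kernel.py below shows):

import struct
import torch

x = torch.empty(1024)          # real launches use CUDA tensors
y = torch.empty(1024)
params = struct.pack('PPi', x.data_ptr(), y.data_ptr(), x.numel())
# torch.ops.triton.launch_kernel(op_id, device, params)  # ids are placeholders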

View File

@@ -1,7 +1,6 @@
from .kernel import *
from .utils import *
import triton.ops
import triton.nn
#import triton.ops
#import triton.nn
# clean-up libtriton resources

View File

@@ -1,28 +0,0 @@
import sys
import os
import triton._C.libtriton as libtriton
torch = None
tensorflow = None
def _import_torch():
global torch
if torch is None:
import torch
def _import_tensorflow():
global tensorflow
if tensorflow is None:
import tensorflow
def has_tensorflow():
result = 'tensorflow' in sys.modules
if result:
_import_tensorflow()
return result
def has_torch():
result = 'torch' in sys.modules
if result:
_import_torch()
return result

View File

@@ -1,199 +1,77 @@
# import for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
import weakref
import contextlib
import io
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton
import os
import time
from struct import pack
import torch
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr
codes = {
libtriton.arg_type.int1: 'B',
libtriton.arg_type.int8: 'B',
libtriton.arg_type.int32: 'I',
libtriton.arg_type.int64: 'Q',
libtriton.arg_type.half: 'H',
libtriton.arg_type.float: 'f',
libtriton.arg_type.double: 'd',
libtriton.arg_type.buffer: 'P'
}
def _build(src, path, name):
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
ccdir = os.path.realpath(ccdir)
# include directories
triton_include_dirs = [os.path.join(ccdir, 'include')]
include_dirs = triton_include_dirs
# library directories
triton_library_dirs = [ccdir]
library_dirs = triton_library_dirs
# libraries
libraries = ['triton']
# add framework
extra_compile_args = []
if fw.has_torch():
prefix = os.path.dirname(fw.torch.__file__)
library_dirs += [os.path.join(prefix, 'lib')]
include_dirs += ['/usr/local/cuda/include/',
os.path.join(prefix, 'lib', 'include'),
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
os.path.join(prefix, 'include'),
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
libraries += ['torch']
abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
else:
assert False
# extra arguments
extra_link_args = []
# dependences
depends = [os.path.realpath(libtriton.__file__)]
# create extension module
ext = setuptools.Extension(
name = name,
language = 'c++',
sources = [src],
include_dirs = include_dirs,
extra_compile_args = extra_compile_args + ['-g0'],
extra_link_args = extra_link_args,
library_dirs = library_dirs,
libraries = libraries,
depends = depends
)
# build extension module
args = ['build_ext']
tmp = tempfile.mkdtemp()
args.append('--build-temp=' + tmp)
args.append('--build-lib=' + path)
args.append('-q')
args = dict(
name = name,
ext_modules = [ext],
script_args = args,
)
with quiet():
setuptools.setup(**args)
shutil.rmtree(tmp)
def _cvt_to_def_str(obj):
# bool
if isinstance(obj, bool):
return str(int(obj))
# torch type
if fw.has_torch():
if isinstance(obj, fw.torch.dtype):
return {fw.torch.int8: 'char',
fw.torch.int16: 'short',
fw.torch.int32: 'int',
fw.torch.int64: 'long',
fw.torch.float16: 'half',
fw.torch.float32: 'float',
fw.torch.float64: 'double'}[obj]
else:
assert False
# default
return str(obj)
sizes = {
libtriton.arg_type.int1: 1,
libtriton.arg_type.int8: 1,
libtriton.arg_type.int32: 4,
libtriton.arg_type.int64: 8,
libtriton.arg_type.half: 2,
libtriton.arg_type.float: 4,
libtriton.arg_type.double: 8,
libtriton.arg_type.buffer: 8
}
def _encode(arg_types):
codes = {
libtriton.arg_type.int1: 'i1',
libtriton.arg_type.int8: 'i8',
libtriton.arg_type.int32: 'i32',
libtriton.arg_type.int64: 'i64',
libtriton.arg_type.half: 'f16',
libtriton.arg_type.float: 'f32',
libtriton.arg_type.double: 'f64',
libtriton.arg_type.buffer: 'buf'
def th_to_triton(obj):
tys = {
torch.int8: 'char',
torch.int16: 'short',
torch.int32: 'int',
torch.int64: 'long',
torch.float16: 'half',
torch.float32: 'float',
torch.float64: 'double'
}
ret = '_'.join(map(codes.get, arg_types))
return ret
if isinstance(obj, torch.dtype):
return [tys[obj]]
if isinstance(obj, list):
return [th_to_triton(x)[0] for x in obj]
return [str(obj)]
def _make_framework_op(arg_types):
name = _encode(arg_types)
# path of .cpp and .so file
home = os.path.expanduser('~')
root = os.path.join(home, '.triton', 'torch', name)
try:
os.makedirs(root)
except FileExistsError:
pass
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(root, f'op{suffix}')
cpp = os.path.join(root, f'op.cpp')
# handle cached .so file
if os.path.exists(so) and os.stat(so).st_size > 0:
tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
so_mtime = os.stat(so).st_mtime
# can use cached if libtriton is older than the .so
if tt_mtime < so_mtime:
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
# create torch source code
lock = os.path.join(root, f'lock')
try:
fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
if os.path.exists(so):
fw.torch.ops.load_library(so)
os.remove(lock)
return getattr(fw.torch.ops.triton, name)
src, _ = libtriton.make_torch_src(name, arg_types)
with open(cpp, 'w+') as handle:
handle.writelines(src)
# create torch.so
_build(cpp, root, 'op')
fw.torch.ops.load_library(so)
os.remove(lock)
return getattr(fw.torch.ops.triton, name)
except FileExistsError:
# spin until .so is fully written
while os.path.exists(lock):
time.sleep(0.01)
fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name)
def cdiv(a, b):
return (a + b - 1) // b
class kernel:
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
self.src = src
# create constants
self.cst = dict()
# create triton op
macros = []
for k, v in defines.items():
cvt = lambda x: _cvt_to_def_str(x)
if(isinstance(v, list)):
values = list(map(cvt, v))
else:
values = [cvt(v)]
macros.append((k, values))
opt = libtriton.options_space()
opt.defines = macros
opt.num_warps = num_warps
self.opt = libtriton.options_space()
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
self.opt.num_warps = num_warps
self.op_id = libtriton.make_op_id()
libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__))
# create pytorch hook
arg_types = libtriton.get_fn_signature(self.src, opt)
self.fw_op = _make_framework_op(arg_types)
self.registered = set()
arg_types = libtriton.get_fn_signature(self.src, self.opt)
size = sum([sizes[x] for x in arg_types])
self.tys = ''.join([codes[x] for x in arg_types])
def set_constant(self, name, value):
libtriton.register_cst(self.op_id, name, value)
def set_constant(self, device, name, value):
libtriton.register_cst((self.op_id, device), name, value)
def __call__(self, *args, **kwargs):
for x in args:
if isinstance(x, torch.Tensor):
device = x.device.index
break
# lazily register function for device
if device not in self.registered:
self.registered.add(device)
libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
# launch options
bench = kwargs['bench'] if 'bench' in kwargs else 0
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
@@ -201,8 +79,8 @@ class kernel:
if 'grid' not in kwargs:
raise RuntimeError('Must provide grid for kernel launch')
grid = kwargs['grid']
libtriton.register_grid(self.op_id, grid)
libtriton.register_grid((self.op_id, device), grid)
# launch
self.fw_op(self.op_id, bench, bench_id, *args)
if bench > 0:
return libtriton.retrieve_scalar(bench_id)
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
torch.cuda.synchronize()
torch.ops.triton.launch_kernel(self.op_id, device, params)

View File

@@ -1,75 +0,0 @@
import triton.frameworks as fw
import triton._C.libtriton as libtriton
import numpy as np
import weakref
def cdiv(a, b):
return (a + b - 1) // b
class tf_empty_proxy:
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
self.tensor = None
def to_tensor(self):
assert self.tensor is not None
return self.tensor
def empty(shape, dtype):
if fw.has_tensorflow():
shape = [fw.tensorflow.constant(x) for x in shape]
shape = fw.tensorflow.stack(shape)
return tf_empty_proxy(shape, dtype)
#return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif fw.has_torch():
return fw.torch.empty(shape, dtype=dtype, device='cuda:0')
def shape(A) :
if fw.has_tensorflow():
return A.shape.as_list()
elif fw.has_torch():
return A.shape
else:
assert False
class id_dict:
# Lazy entry (e.g., for tensorflow) when the benchmark value is
# not known at graph compile time
class lazy_entry:
def __init__(self, id):
self.id = id
def get(self):
return libtriton.retrieve_scalar(self.id)
def __init__(self):
self.data = dict()
def __delitem__(self, key):
del self.data[key]
@staticmethod
def _get_key(key):
if fw.has_tensorflow():
if isinstance(key, fw.tensorflow.Tensor):
key = id(key.op)
if fw.has_torch():
if isinstance(key, fw.torch.Tensor):
key = id(key)
return key
def __getitem__(self, key):
ret = self.data[id_dict._get_key(key)]
if isinstance(ret, id_dict.lazy_entry):
return ret.get()
return ret
def __len__(self):
return len(self.data)
def __setitem__(self, key, value):
self.data[id_dict._get_key(key)] = value

View File

@@ -1,4 +1,4 @@
foreach(PROG dot copy)
foreach(PROG dot copy conv)
set(TARGET bench_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})

tests/bench/conv.cc (new file)
View File

@@ -0,0 +1,39 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "conv.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<int, int, int, int, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs = {
// {1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 128, 128, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 384, 384, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 512, 512, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 768, 768, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 1024, 1024, 3, 3, 1, 1, 1, 1},
// {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1},
{1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1},
// {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1}
};
int Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w;
for(const auto& c: configs){
std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c;
std::cout << "// " << c;
for(auto perf: bench_conv(stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

View File

@@ -9,11 +9,40 @@ int main() {
// shapes to benchmark
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
for(auto ord: std::vector<std::vector<int>>{{0, 1}})
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
std::vector<config_t> tmp = {
// config_t{ord, x[0], x[1], 128, 128, 128},
// config_t{ord, x[0], x[1], 256, 256, 256},
// config_t{ord, x[0], x[1], 384, 384, 384},
// config_t{ord, x[0], x[1], 512, 512, 512},
// config_t{ord, x[0], x[1], 768, 768, 768},
config_t{ord, x[0], x[1], 1024, 1024, 1024},
// config_t{ord, x[0], x[1], 1280, 1280, 1280},
// config_t{ord, x[0], x[1], 1536, 1536, 1536},
// config_t{ord, x[0], x[1], 2048, 2048, 2048},
// config_t{ord, x[0], x[1], 8192, 8192, 8192},
// config_t{ord, x[0], x[1], 256, 16, 256},
// config_t{ord, x[0], x[1], 512, 16, 512},
// config_t{ord, x[0], x[1], 768, 16, 768},
// config_t{ord, x[0], x[1], 1024, 16, 1024},
// config_t{ord, x[0], x[1], 1280, 16, 1280},
// config_t{ord, x[0], x[1], 1536, 16, 1536},
// config_t{ord, x[0], x[1], 2048, 16, 2048},
// config_t{ord, x[0], x[1], 3072, 16, 3072},
// config_t{ord, x[0], x[1], 4096, 16, 4096},
// config_t{ord, x[0], x[1], 5120, 16, 5120},
// config_t{ord, x[0], x[1], 6144, 16, 6144},
// config_t{ord, x[0], x[1], 7168, 16, 7168},
// config_t{ord, x[0], x[1], 64, 64, 4096},
// config_t{ord, x[0], x[1], 64, 64, 8192},
// config_t{ord, x[0], x[1], 64, 64, 16384},
// config_t{ord, x[0], x[1], 64, 64, 32768},
// config_t{ord, x[0], x[1], 64, 64, 65536},
// config_t{ord, x[0], x[1], 64, 64, 131072}
// config_t{ord, x[0], x[1], 127008, 768, 576},
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
// config_t{ord, x[0], x[1], 16, 2048, 2048},

tests/common/conv.h (new file)
View File

@@ -0,0 +1,150 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/conv.h"
#include "cuda/cublas.h"
#include "util.h"
struct conv_arg_t{
CUdeviceptr a;
CUdeviceptr b;
CUdeviceptr c;
float alpha;
int M;
int N;
int K;
int pad_h;
int pad_w;
int stride_h;
int stride_w;
CUdeviceptr adelta;
int lda_z;
int lda_ci;
int lda_h;
int lda_w;
int ldb_ci;
int ldb_r;
int ldb_s;
int ldb_co;
int ldc_z;
int ldc_co;
int ldc_p;
int ldc_q;
};
enum run_mode_t {
BENCH,
TEST
};
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<class T>
void triton_conv(drv::stream* stream,
int Z, int CI, int H, int W, int CO, int R, int S,
int pad_h, int pad_w, int stride_h, int stride_w,
run_mode_t mode, std::vector<double>& bench, bool &test){
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::context* context = stream->context();
int P = (H + 2*pad_h - R)/stride_h + 1;
int Q = (W + 2*pad_w - S)/stride_w + 1;
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CO*P*Q*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CI*H*W*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*CO*dt_nbytes));
auto ddelta = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*4));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
std::vector<int32_t> hdelta(CI*R*S);
int TK = 16;
for(int i = 0; i < hdelta.size(); i++){
int s = i % S;
int cr = i / S;
int r = cr % R;
int c = cr / R;
int nexti = i + TK;
int nexts = nexti % S;
int nextcr = nexti / S;
int nextr = nextcr % R;
int nextc = nextcr / R;
hdelta[i] = (nextc - c)*W*H + (nextr - r)*W + (nexts - s);
}
stream->write(&*ddelta, true, 0, hdelta);
// macros
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {std::to_string(TK)}});
opt.defines.push_back({"TZ", {"1"}});
opt.defines.push_back({"RR", {std::to_string(R)}});
opt.defines.push_back({"SS", {std::to_string(S)}});
opt.defines.push_back({"PP", {std::to_string(P)}});
opt.defines.push_back({"QQ", {std::to_string(Q)}});
opt.defines.push_back({"HH", {std::to_string(H)}});
opt.defines.push_back({"WW", {std::to_string(W)}});
opt.num_warps = {2, 4};
// kernels
rt::function function(src::conv, opt);
conv_arg_t args{*da->cu(), *db->cu(), *dc->cu(), 1, Z*P*Q, CO, CI*R*S,
pad_h, pad_w, stride_h, stride_w,
*ddelta->cu(),
W*H*CI, W*H, W, 1,
CO*S*R , CO*S, CO, 1,
Q*P*CO, Q*P, Q, 1};
auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
ceil(CO , x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
bench.push_back(tflops(triton_ns));
}
std::vector<double> bench_conv(drv::stream* stream, dtype_t dtype,
int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S,
int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) {
std::vector<double> bench;
bool test;
switch(dtype){
case HALF: triton_conv<half_float::half>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
case FLOAT: triton_conv<float>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
case DOUBLE: triton_conv<double>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
default: break;
}
return bench;
}

View File

@@ -12,6 +12,15 @@ int32_t off(const std::vector<int32_t>& idx, const std::vector<int32_t>& strides
return res;
}
struct copy_arg_t{
CUdeviceptr X;
CUdeviceptr Y;
int S0;
int S1;
int S2;
};
enum run_mode_t {
BENCH,
TEST
@@ -115,16 +124,16 @@ void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
// kernel
rt::function function(src::copy_nd[rank - 1], opt);
std::vector<rt::arg> args = {&*dx, &*dy};
for(int32_t d: shape)
args.push_back(d);
copy_arg_t args = {*dx->cu(), *dy->cu(), shape[0]};
if(shape.size() > 1) args.S1 = shape[1];
if(shape.size() > 2) args.S2 = shape[2];
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
auto grid = grid_nd(shape, ts);
// metrics
if(mode == BENCH){
auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
bench.push_back(gbps(triton_ns));
}
@@ -136,7 +145,7 @@ void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
for(size_t i = 0; i < hx.size(); i++)
hx[i] = static_cast<T>((float)rand()/RAND_MAX);
stream->write(&*dx, true, 0, hx);
function(args, grid, stream);
function((void**)&args, sizeof(args), grid, stream);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
cc_copy_nd(hx, ry, shape, x_order, y_order);

View File

@@ -2,6 +2,7 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include <tuple>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
@@ -12,6 +13,20 @@
#include "util.h"
struct dot_arg_t{
CUdeviceptr a;
CUdeviceptr b;
CUdeviceptr c;
float alpha;
int M;
int N;
int K;
int lda;
int ldb;
int ldc;
CUdeviceptr locks;
};
template<class T, bool AT, bool BT>
static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
@@ -64,7 +79,7 @@ template<> struct to_string<double>{
};
template<class T>
bool triton_dot(drv::stream* stream, bool AT, bool BT,
void triton_dot(drv::stream* stream, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
int32_t TM, int32_t TN, int32_t TK, int32_t nwarp,
const std::vector<int>& a_order, const std::vector<int>& b_order,
@@ -79,9 +94,11 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
std::vector<std::string> sb = { "1", "ldb" };
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
// macros
rt::function::options_space_t opt;
@@ -106,33 +123,41 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
opt.defines.push_back({"TM", {std::to_string(TM)}});
opt.defines.push_back({"TN", {std::to_string(TN)}});
opt.defines.push_back({"TK", {std::to_string(TK)}});
opt.defines.push_back({"TZ", {"1"}});
opt.num_warps = {nwarp};
}
if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"32"}});
opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"16"}});
opt.defines.push_back({"TZ", {"1"}});
opt.num_warps = {4};
}
// kernels
rt::function function(src::dot, opt);
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
auto grid = grid2d(M, N);
dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(),
1, M, N, K, lda, ldb, ldc, *dlocks->cu()};
auto grid = [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
// metrics
if(mode == BENCH){
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
bench.push_back(tflops(triton_ns));
// // cublas
// cublas
// if(cublas::cublasinit()){
// T alpha(static_cast<double>(1));
// T beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
// ldc, nullptr, fastest); }, stream);
// bench.push_back(tflops(cublas_ms));
@@ -154,7 +179,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
// run kernel
function(args, grid, stream);
function((void**)&args, sizeof(args), grid, stream);
// write back
stream->synchronize();
// compare with CPU

View File

@@ -13,6 +13,15 @@
namespace drv = triton::driver;
namespace rt = triton::runtime;
struct reduce_arg_t{
CUdeviceptr X;
CUdeviceptr Y;
int S0;
int S1;
int S2;
};
template<class T>
void cc_reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
assert(axis <= shapes.size() - 1);
@@ -123,16 +132,16 @@ void triton_reduce_nd(drv::stream* stream, const std::vector<int32_t>& shape_x,
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));
// grid
std::vector<rt::arg> args = {&*dx, &*dy};
for(int32_t d: shape_x)
args.push_back(d);
reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]};
if(shape_x.size() > 1) args.S1 = shape_x[1];
if(shape_x.size() > 2) args.S2 = shape_x[2];
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
auto grid = grid_nd(shape_x, ts);
// metrics
if(mode == BENCH){
auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
bench.push_back(gbps(triton_ns));
}
@@ -144,7 +153,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector<int32_t>& shape_x,
init_zeros(hy);
init_rand(hx);
stream->write(&*dx, true, 0, hx);
function(args, grid, stream);
function((void**)&args, sizeof(args), grid, stream);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
cc_reduce_nd(ry, hx, op, axis, shape_x);

134
tests/common/src/conv.h Normal file
View File

@@ -0,0 +1,134 @@
namespace src {
const char *conv =
R"(
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M __retune,
int N __retune,
int K __retune,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs [TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr [TK] = rcir % RR;
int rci [TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz [:, newaxis] * lda_z +
rci[newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE* pa[TM, TK] = A + offa;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr [:, newaxis] * ldb_r +
rs [:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE* pb[TK, TN] = B + offb;
// pointers to delta
int* padelta[TK] = ADELTA + rk;
int adelta[TK] = *padelta;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkbn[TK, TN] = rn[newaxis, :] < N;
bool checkbk[TK, TN] = rk[:, newaxis] < K;
bool checkb[TK, TN] = checkbn && checkbk;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
// increment A
pa += adelta[newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkbk = k > TK;
bool checkb[TK, TN] = checkbn && checkbk;
a = checka ? *pa : 0;
b = *?(checkb)pb;
padelta += TK;
adelta = *padelta;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co +
rp [:, newaxis] * ldc_p +
rm [:, newaxis] * 1;
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";
}

View File

@@ -7,7 +7,7 @@ namespace src {
R"(
void copy1d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __readonly __aligned(16),
int S0) {
int S0 __retune) {
int pid0 = get_program_id(0);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
TYPE* px[TS0] = X + rs0;
@@ -20,8 +20,8 @@ void copy1d(TYPE * X __noalias __readonly __aligned(16),
R"(
void copy2d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0 __multipleof(8),
int S1 __multipleof(8)) {
int S0 __multipleof(8) __retune,
int S1 __multipleof(8) __retune) {
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
@@ -37,9 +37,9 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
R"(
void copy3d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0 __multipleof(8),
int S1 __multipleof(8),
int S2 __multipleof(8)) {
int S0 __multipleof(8) __retune,
int S1 __multipleof(8) __retune,
int S2 __multipleof(8) __retune) {
// program id
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);

View File

@@ -6,13 +6,17 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M, int N, int K,
int M __retune,
int N __retune,
int K __retune __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
int ldc __multipleof(8),
int* locks) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
@@ -20,7 +24,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// pointers to operands
int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
@@ -35,9 +42,9 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE b[SHAPE_B] = checkb ? *pb : 0;
// reduction loop
float c[TM, TN] = 0;
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
c += USEA @ USEB;
acc += USEA @ USEB;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
pa += TK * STRIDE_AK;
@@ -45,7 +52,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
a = *?(checka)pa;
b = *?(checkb)pb;
}
//c = c * alpha;
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
@@ -53,7 +61,22 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
*?(checkc)pc = (TYPE[TM, TN])c;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";

View File

@@ -16,7 +16,7 @@ int main() {
for(int nwarps: std::vector<int>{4})
for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true}){
configs.push_back(config_t{HALF, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
configs.push_back(config_t{FLOAT, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
}
// test
dtype_t dtype;