Compare commits
44 Commits

SHA1
526a3468e0
5700260e5e
91d21134a7
010f85e91d
3195aca452
d5610b685b
30ac1359b9
8e5744f171
2d6484482f
bd2067606c
f01bdd6207
ec923cfc5e
cfcda09679
3683b38963
cc7c77246b
baabaf0d58
8e3ec0556f
7c519e09c5
8e9d793d11
547434d7f0
c1f6ce68f9
b48460f51c
52eca7676b
a92031b214
18a4cefec7
735215d9c4
05214d22e3
d390d941d8
f8614b2490
8f9233e546
cd21151b98
9da8fe11ed
65a41af8dc
142a1dd42b
f88a7a0027
bfa85ef030
3848d25450
1a20099582
fc3747d59c
5fddc2062e
91bfb0d182
be02315168
2dd3fcb270
f5d47536c5
@@ -21,7 +21,7 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++11")

# Tests
if(BUILD_TESTS)
@@ -33,8 +33,10 @@ endif()
if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# PyBind11 wrapper source file
-set(PYTHON_SRC bindings.cc)
+set(PYTHON_SRC bindings.cc launch.cc)
+set_source_files_properties(launch.cc PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0")
include_directories("." ${PYTHON_INCLUDE_DIRS})
link_directories(${PYTHON_LINK_DIRS})
endif()
@@ -43,3 +45,6 @@ file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton ${LLVM_LIBRARIES})
+
+if(BUILD_PYTHON_MODULE)
+target_link_libraries(triton ${TORCH_LIBRARIES})
+endif()
README.md (30 changes)

@@ -1,29 +1,41 @@
# Triton

-This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with much higher flexibility than [TVM](https://github.com/apache/incubator-tvm) and without having to manually specify compute schedules.
+This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.

The main components of Triton at the moment are:

- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level (think numpy-like array operations in a C-like language).
- **Triton-IR**: A special-purpose intermediate representation (Triton-IR) for aiding array-level program analysis and optimizations in Triton-C programs.
- **Triton-JIT**: An optimizing just-in-time compiler for Triton-IR, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)). This includes transparent support for mixed-precision and Tensor Cores.

Bindings for **automatic** PyTorch custom op generation are included in **PyTriton**, along with a small DSL based on einsum that supports convolutions, shift-convolutions, direct einsums, etc.

-The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please cite us if you use our work!
+The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!

## Installation

Triton is a fairly self-contained package and uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and LLVM-8.0+ for code generation.

You can install the latest release with pip as follows:
```
sudo apt-get install llvm-9-dev
pip install triton
```

or the latest development version with:
```
sudo apt-get install llvm-8-dev
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
```

for the C++ package:
```
git clone https://github.com/ptillet/triton.git;
cd triton;
mkdir build;
cd build;
cmake ../;
make -j8;
```

## Getting Started

@@ -32,8 +44,4 @@ Please visit the [documentation](https://docs.triton-lang.org) to get started wi

## Contributing

Please keep in mind that this is a project I have been carrying out completely on my own as part of my Ph.D. thesis. While I am confident in the approach, there are still many things to fix and to polish. Please contact me (ptillet AT g.harvard.edu) or raise an issue if you want to contribute!

-## ISAAC (deprecated) for fast inference
-
-Before working on Triton, I wrote custom auto-tuned PTX code for fast, quantized inference on GPUs. While this project is now deprecated, you can use it at your own risk by checking out the "isaac" tag in this repository.
@@ -1,6 +1,15 @@
Installation
============

+Triton can be installed directly from pip with the following command:
+
+.. code-block:: python
+
+   pip install triton
+
+See below for more detailed information on custom builds.

.. toctree::
   :maxdepth: 1
@@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just
"""
# create callable kernel for the source-code
# options: 4 warps and a -DTILE=1024
-kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
+kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])

# Forward pass
@staticmethod
@@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just
zb = add(x, y)
diff = (za - zb).abs().max()
print(diff)
+print(torch.allclose(za, zb))

Executing the above code will:

@@ -97,3 +98,5 @@ Executing the above code will:
- Call the resulting custom op

In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
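As a rough illustration of this caching behavior, one could time two consecutive calls to the `add` op defined above (a minimal sketch; the tensor size and the exact speedup are hypothetical):

.. code-block:: python

    import time
    import torch

    x = torch.rand(16384, device='cuda')
    y = torch.rand(16384, device='cuda')

    t0 = time.perf_counter()
    add(x, y)                    # first call: compiles, fills $HOME/.triton/cache
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    add(x, y)                    # second call: reuses the cached binary
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print(t1 - t0, t2 - t1)      # the first call should dominate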
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.
@@ -10,7 +10,7 @@ The purpose of this section is to present a Triton-C implementation of matrix mu
Compute Kernel
==============

Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:

.. code-block:: C

@@ -24,7 +24,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
int rm[TM] = pm * TM + 0 ... TM; //(3)
int rn[TN] = pn * TN + 0 ... TN; //(4)
int rk[TK] = 0 ... TK; //(5)
// initialize accumulator
float c[TM, TN] = 0; //(6)
// pointers to operands
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
@@ -32,10 +32,10 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
// reduction loop
for(int k = K; k > 0; k-= TK){
// fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
-c += dot(a, b); //(11)
+c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
@@ -85,10 +85,10 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with
.. code-block:: C

// pre-fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
-c += dot(a, b);
+c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
@@ -113,7 +113,7 @@ Rematerialization
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;

------------------------------------
@@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
-c += dot(USE_A, USE_B);
+c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
@@ -163,16 +163,16 @@ All matrix multiplications variants can then be retrieved using the following co
.. code-block:: C

// A is not transposed
-DUSE_A=a -DSTRIDE_AK=1 -DSTRIDE_AM=lda
-DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
// A is transposed
-DUSE_A=^a -DSTRIDE_AK=lda -DSTRIDE_AM=1
-DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
// B is not transposed
-DUSE_B=b -DSTRIDE_BK=ldb -DSTRIDE_BN=1
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
// B is transposed
-DUSE_B=^b -DSTRIDE_BK=1 -DSTRIDE_BN=ldb
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK

@@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros:

// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.
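In the Python API shown earlier (`triton.kernel(src, defines=..., num_warps=...)`), the same auto-tuning space can be expressed by passing lists of candidate values in the `defines` dictionary; a minimal sketch, assuming the `src` string of this tutorial:

.. code-block:: python

    # lists of values mark macros to auto-tune over, mirroring
    # -DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
    kernel = triton.kernel(src,
                           defines={'TM': [32, 64, 128],
                                    'TN': [32, 64, 128],
                                    'TK': [8, 16]},
                           num_warps=[4])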
@@ -3,7 +3,7 @@ Matrix Transpositions
*********************

Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.

Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity and interferes with auto-tuning.

@@ -16,7 +16,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
.. code-block:: C

// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
__global__ void transpose(TYPE * X, TYPE * Y,
                          int M, int N, int ldx, int ldy) {
// extract program ID
int pidm = get_program_id(0); //(1)
@@ -25,12 +25,12 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
-TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
-TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
+TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
+TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}

At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:

- Statements (1) and (2) extract the coordinates of the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
@@ -54,7 +54,7 @@ which will be used in statements (5) and (6) to construct tiles of pointers

- Statement (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:

::

    │ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx,  ...,  ...,  X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx │
    │                  ⋮                                                       ⋮                     │
    │                  ⋮                                                       ⋮                     │

@@ -73,6 +73,65 @@ which will be used in statements (5) and (6) to construct tiles of pointers
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.

==================================
A Note on Numpy-style Broadcasting
==================================

The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.

First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix.

.. code-block:: python

    import numpy as np

    vec = np.linspace(0, 11, 12)
    mat = vec.reshape((4, 3))

Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will cover the top half, and tile 1 the bottom):

::

    [[ 0,  1,  2],
     [ 3,  4,  5],
     [ 6,  7,  8],
     [ 9, 10, 11]]

Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like tile 0 to hold the values:

::

    [[ 0,  1,  2],
     [ 3,  4,  5]]

We construct ranges `rm` and `rn` as:

::

    rm = [0, 1]
    rn = [0, 1, 2]

Using numpy-style broadcasting, we can add these together to create a matrix:

::

    rm[:, np.newaxis] + rn[np.newaxis, :]

                rn -> [0, 1, 2]
    rm -> [0,      [[0, 1, 2],
           1]       [1, 2, 3]]

The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension). In this case this is 3, so what we really need is:

::

    ldx = 3
    px = rm[:, np.newaxis] * ldx + rn[np.newaxis, :]

`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example).
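Putting the pieces together, here is a runnable consolidation of the numpy snippets above (nothing here is Triton-specific; it only checks the pointer arithmetic):

.. code-block:: python

    import numpy as np

    vec = np.linspace(0, 11, 12)
    mat = vec.reshape((4, 3))

    TM, TN, ldx = 2, 3, 3
    pidm, pidn = 0, 0                        # tile 0: the top half

    rm = pidm * TM + np.arange(TM)           # [0, 1]
    rn = pidn * TN + np.arange(TN)           # [0, 1, 2]

    # flat offsets of every element of the tile
    px = rm[:, np.newaxis] * ldx + rn[np.newaxis, :]
    print(vec[px])                           # [[0. 1. 2.]
                                             #  [3. 4. 5.]]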
==========================
The __multipleof attribute
==========================
@@ -83,19 +142,19 @@ For this reason, Triton provides a __multipleof(N) attributes for variables that

.. code-block:: C

__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
                          int ldx __multipleof(8),
                          int ldy __multipleof(8)) {
// ...
}

==========================
Bounds Checking
==========================

-You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
+You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:

.. code-block:: C

@@ -108,6 +167,8 @@ You might have noticed that the above code will fail when `M` and `N` are not mu
// conditional write-back using the conditional dereferencing operator '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}

Here, statement (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.
@@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code
'TYPE' : dtype,
'AT' : transpose_a,
'BT' : transpose_b,
-'TM' : [32, 64, 128]
-'TN' : [32, 64, 128]
-'TK' : [8]
+'TM' : [32, 64, 128],
+'TN' : [32, 64, 128],
+'TK' : [8],
# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
@@ -177,4 +177,4 @@ A callable operation can be created using the :code:`apply` method of the :code:
dot = _dot.apply

And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~
@@ -11,6 +11,7 @@ namespace ir {
class module;
class phi_node;
class splat_inst;
+class cast_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
@@ -44,6 +45,7 @@ private:
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
+std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
@@ -53,6 +55,7 @@ private:
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
+std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
@@ -19,7 +19,7 @@ namespace transform{

class peephole {
private:
-bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
+// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
@@ -32,7 +32,7 @@ public:
driver::context* context() const;
// methods
virtual void synchronize() = 0;
-virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL) = 0;
+virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **extra = NULL) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
@@ -53,7 +53,7 @@ public:

// Overridden
void synchronize();
-void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
+void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -66,7 +66,7 @@ public:

// Overridden
void synchronize();
-void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
+void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -80,7 +80,7 @@ public:

// Overridden
void synchronize();
-void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
+void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -39,7 +39,9 @@ enum attribute_kind_t {
writeonly,
noalias,
aligned,
-multiple_of
+multiple_of,
+retune,
+not_implemented
};

class attribute {
@@ -112,6 +114,7 @@ public:
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
+std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }

// visitor
@@ -20,15 +20,18 @@ class visitor;
//===----------------------------------------------------------------------===//

class value {
public:
+typedef std::set<user*> users_t;
+
public:
// constructor
value(type *ty, const std::string &name = "");
virtual ~value(){ }
// uses
void add_use(user* arg);
-unsigned erase_use(user* arg);
+users_t::iterator erase_use(user* arg);
const std::set<user*> &get_users() { return users_; }
-virtual void replace_all_uses_with(value *target);
+void replace_all_uses_with(value *target);
// name
void set_name(const std::string &name);
const std::string &get_name() const { return name_; }
@@ -41,7 +44,7 @@ private:

protected:
type *ty_;
-std::set<user*> users_;
+users_t users_;
};

//===----------------------------------------------------------------------===//
@@ -63,6 +66,7 @@ public:
user(type *ty, unsigned num_ops, const std::string &name = "")
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
}
+virtual ~user() { }

// Operands
const ops_t& ops() { return ops_; }
@@ -74,8 +78,7 @@ public:
unsigned get_num_hidden() const;

// Utils
-void replace_all_uses_with(value *target);
-void replace_uses_of_with(value *before, value *after);
+value::users_t::iterator replace_uses_of_with(value *before, value *after);

private:
@@ -64,7 +64,8 @@ public:
ALIGNED,
NOALIAS,
READONLY,
-WRITEONLY
+WRITEONLY,
+RETUNE,
};

KindT kind;
@@ -53,13 +53,14 @@ inline bool is_int_type(arg_type ty){
}

class arg {
-private:
+public:
union value_t {
bool int1;
int8_t int8;
int16_t int16;
int32_t int32;
int64_t int64;
+uint16_t fp16;
float fp32;
double fp64;
driver::buffer* buf;
@@ -67,6 +68,7 @@ private:

public:
// construct from primitive types
+arg(arg_type ty, value_t val): ty_(ty) { val_ = val; }
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
@@ -3,7 +3,7 @@
#ifndef _TRITON_RUNTIME_FUNCTION_H_
#define _TRITON_RUNTIME_FUNCTION_H_

#include <map>
#include <vector>
#include <string>
#include <memory>
@@ -62,6 +62,7 @@ public:
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<int> num_warps;
+std::vector<int> recompile_key;
};

struct options_t {
@@ -94,19 +95,25 @@ private:
// accessors
const options_t opt() const { return opt_; }
const driver::module* parent() const { return &*parent_; }
const driver::kernel* bin() const { return &*bin_; }
arg_type param_ty(size_t i) const { return param_tys_.at(i);}
+const std::vector<arg_type>& param_tys() const { return param_tys_; }
+std::vector<int> retune() const { return retune_; }
// entry points
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
+void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;

private:
std::shared_ptr<driver::kernel> bin_;
std::shared_ptr<driver::module> parent_;
std::vector<arg_type> param_tys_;
+std::vector<int> retune_;
options_t opt_;
std::string name_;
};

private:
-typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
+typedef std::pair<driver::device*, std::vector<int32_t>> cache_key_t;

private:
// cache
@@ -118,16 +125,15 @@ private:
caller *make(driver::stream *stream, options_t opt);
void precompile(driver::stream *stream, const options_space_t& tuning_space);
// autotune
-function::cache_key_t get_key(driver::stream *stream, const std::vector<arg>& args);
-caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
+caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);

public:
static std::string preheader();

public:
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
-void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
-void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
+void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
+void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void set_cst(const std::string& name, void* data, size_t n_bytes);

private:
@@ -138,6 +144,8 @@ private:
options_space_t opt_;
std::set<options_t> compiled_;
std::map<options_t, std::unique_ptr<caller>> callers_;
+std::vector<int> args_off_;
+size_t args_size_;
// caching
std::string cache_ref_;
std::string cache_path_;
@@ -38,19 +38,22 @@ inline double bench(std::function<void()> const & op, driver::stream * stream, b
double total_time = 0;
op();
stream->synchronize();
-while(total_time*1e-9 < 1e-2){
-float norm = 1;
+tmr.start();
+for(size_t i = 0; i < 10; i++){
+// while(total_time*1e-9 < 1e-2){
+// float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(normalize)
//   if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
//     norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
-tmr.start();
op();
stream->synchronize();
-times.push_back(norm*tmr.get().count());
-total_time+=times.back();
+// times.push_back(norm*tmr.get().count());
+// total_time+=times.back();
}
-return *std::min_element(times.begin(), times.end());
+stream->synchronize();
+return (float)tmr.get().count() / 10;
+// return *std::min_element(times.begin(), times.end());
}
}
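For reference, the new timing strategy in this hunk (one warm-up call plus synchronization, then a fixed number of timed iterations whose total is averaged) can be sketched in Python; a rough analogue, not part of the repository:

```python
import time

def bench(op, reps=10):
    op()                        # warm up (triggers JIT compilation)
    t0 = time.perf_counter()
    for _ in range(reps):
        op()                    # each op() is assumed to synchronize internally
    return (time.perf_counter() - t0) / reps
```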
@@ -287,10 +287,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
-if(lhs_cst_info[d].num_cst > 0)
-  lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
-if(rhs_cst_info[d].num_cst > 0)
-  rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
+lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
+rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
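A rough numeric illustration of the rule in this hunk (a sketch, not repository code): for an integer addition, a contiguous run in one operand is only guaranteed to survive up to the divisibility of the other operand's values, so the bound is the gcd of one side's contiguity with the other side's starting multiple.

```python
from math import gcd

def max_contiguous_add(lhs_contig, lhs_mult, rhs_contig, rhs_mult):
    # mirrors: lvalue = gcd(rhs_max_contiguous, lhs_starting_multiple)
    #          rvalue = gcd(lhs_max_contiguous, rhs_starting_multiple)
    return max(gcd(rhs_contig, lhs_mult), gcd(lhs_contig, rhs_mult))

# e.g. rhs = [0,1,2,3,...] (contiguity 4) plus an lhs whose values are multiples of 8
print(max_contiguous_add(1, 8, 4, 1))   # 4: e.g. 8 + [0,1,2,3] = [8,9,10,11]
```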
@@ -329,9 +327,16 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}

+std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
+  auto result = populate_max_contiguous(v->get_operand(0));
+  return add_to_cache(v, result, max_contiguous_);
+}
+
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
  return max_contiguous_.at(v);
+if(auto *x = dynamic_cast<ir::cast_inst*>(v))
+  return populate_max_contiguous_cast(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
  return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
@@ -436,16 +441,16 @@ std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
}

+std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
+  auto result = populate_starting_multiple(x->get_operand(0));
+  return add_to_cache(x, result, starting_multiple_);
+}
+
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
  return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
}
-if(auto *x = dynamic_cast<ir::instruction*>(v)){
-  unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
-  if(multiple_of > 0)
-    return add_to_cache(x, {multiple_of}, starting_multiple_);
-}
if(auto *x = dynamic_cast<ir::argument*>(v)){
  std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
  for(auto attr: attributes){
@@ -466,6 +471,13 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
  return starting_multiple_.at(v);
+if(auto *x = dynamic_cast<ir::instruction*>(v)){
+  unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
+  if(multiple_of > 0)
+    return add_to_cache(x, {multiple_of}, starting_multiple_);
+}
+if(auto *x = dynamic_cast<ir::cast_inst*>(v))
+  return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
  return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
@@ -510,6 +522,11 @@ void align::populate(ir::value *v) {

void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
+// ir::for_each_value(mod, [this](ir::value* v) {
+//   if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
+//     std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
+//               << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
+// });
}
@@ -22,6 +22,30 @@ namespace codegen{

using namespace llvm;

+// Function for extended Euclidean Algorithm
+inline int gcd_impl(int a, int b, int *x, int *y)
+{
+  // Base Case
+  if (a == 0)
+  {
+    *x = 0;
+    *y = 1;
+    return b;
+  }
+  int x1, y1; // To store results of recursive call
+  int gcd = gcd_impl(b%a, a, &x1, &y1);
+  // Update x and y using results of
+  // recursive call
+  *x = y1 - (b/a) * x1;
+  *y = x1;
+  return gcd;
+}
+
+inline int gcd(int a, int b) {
+  int x, y;
+  return gcd_impl(a, b, &x, &y);
}
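One plausible motivation for introducing `gcd` here, used by the `vector_size` hunks below in place of `std::min`: a vector width must divide both the contiguous run length and the alignment for every vectorized access to stay inside the run and stay aligned, and `min` does not guarantee divisibility. A small sketch:

```python
from math import gcd

contiguous, alignment = 6, 4
print(min(contiguous, alignment))   # 4: a 4-wide access starting at element 4
                                    #    would cross the 6-element contiguous run
print(gcd(contiguous, alignment))   # 2: divides both, so every 2-wide access
                                    #    is aligned and stays inside the run
```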
llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
using llop = llvm::Instruction::BinaryOps;
@@ -153,6 +177,7 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) {
case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly);
case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly);
case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value());
+case ir::retune: return llvm::Attribute::get(ctx, llvm::Attribute::None);
default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
}
}
@@ -309,7 +334,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
unsigned contiguous = 1;
if(ld < x->get_type()->get_tile_rank())
  contiguous = result->axis(ld).contiguous;
-unsigned vector_size = std::min<unsigned>(contiguous, alignment);
+unsigned vector_size = gcd(contiguous, alignment);

unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
@@ -329,7 +354,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
unsigned contiguous = 1;
if(ld < x->get_type()->get_tile_rank())
  contiguous = result->axis(ld).contiguous;
-unsigned vector_size = std::min<unsigned>(contiguous, alignment);
+unsigned vector_size = gcd(contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size));
@@ -347,7 +372,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
std::map<unsigned, Value*> packets;
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
-unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
+unsigned vector_size = gcd(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
@@ -400,7 +425,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// extract result element
for_each(x, [&](indices_t idx){
distributed_tile* result = (distributed_tile*)tmap_.at(x);
-unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
+unsigned vector_size = gcd(result->axis(ld).contiguous, alignment);
unsigned linear = result->get_linear_index(idx);
unsigned id = linear / vector_size;
// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2});
@@ -418,6 +443,8 @@ void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) {
});
}

void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand());
distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand());
@@ -425,7 +452,7 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
int vector_size = 1;
int ld = ptrs->get_order()[0];
unsigned alignment = alignment_->get(st->get_pointer_operand(), ld);
-vector_size = std::min<unsigned>(ptrs->axis(ld).contiguous, alignment);
+vector_size = gcd(ptrs->axis(ld).contiguous, alignment);
// create packets
std::map<unsigned, Value*> packets;
ir::value *arg = st->get_value_operand();
@@ -556,14 +583,14 @@ void generator::visit_exp_inst(ir::exp_inst* x){
// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);

-FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false);
+std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
+FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);

for_each(x, [&](indices_t idx){
Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e);
-set_value(x, idx, builder_->CreateCall(ex2, {ex2arg}));
+set_value(x, idx, builder_->CreateCall(ex2, std::vector<llvm::Value*>{ex2arg}));
});
}
@@ -584,7 +611,7 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val,
                                           AtomicOrdering::Monotonic,
                                           AtomicOrdering::Monotonic);
-old = builder_->CreateExtractValue(old, {0});
+old = builder_->CreateExtractValue(old, std::vector<unsigned>{0});
Value *atom_ptr;
atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))));
atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3));
@@ -618,8 +645,25 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
tgt_->add_memfence(module, *builder_);
}

-void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
-  throw std::runtime_error("unsupported");
+void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
+  BasicBlock *current = builder_->GetInsertBlock();
+  Module *module = current->getModule();
+  Value *rmw_ptr = vmap_.at(add->get_operand(0));
+  Value *rmw_val = vmap_.at(add->get_operand(1));
+  Value *tid = tgt_->get_local_id(module, *builder_, 0);
+  Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
+  BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
+  BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
+  tgt_->add_memfence(module, *builder_);
+  tgt_->add_barrier(module, *builder_);
+  builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
+  builder_->SetInsertPoint(tid_0_bb);
+  builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
+                            AtomicOrdering::Monotonic,
+                            SyncScope::System);
+  builder_->CreateBr(tid_0_done_bb);
+  builder_->SetInsertPoint(tid_0_done_bb);
+  tgt_->add_memfence(module, *builder_);
}

void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
@@ -640,8 +684,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *

Type *fp32_ty = builder_->getFloatTy();
Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
-Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
-FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
+Type *fp32_pack8_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
+FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);

Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
@@ -720,15 +764,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc
};
-Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
-fc[idx[0]] = builder_->CreateExtractValue(nc, {0});
-fc[idx[1]] = builder_->CreateExtractValue(nc, {1});
-fc[idx[2]] = builder_->CreateExtractValue(nc, {2});
-fc[idx[3]] = builder_->CreateExtractValue(nc, {3});
-fc[idx[4]] = builder_->CreateExtractValue(nc, {4});
-fc[idx[5]] = builder_->CreateExtractValue(nc, {5});
-fc[idx[6]] = builder_->CreateExtractValue(nc, {6});
-fc[idx[7]] = builder_->CreateExtractValue(nc, {7});
+Value *nc = builder_->CreateCall(mma_fn, std::vector<llvm::Value*>{ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
+fc[idx[0]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{0});
+fc[idx[1]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{1});
+fc[idx[2]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{2});
+fc[idx[3]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{3});
+fc[idx[4]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{4});
+fc[idx[5]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{5});
+fc[idx[6]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{6});
+fc[idx[7]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{7});
}
}
}
@@ -770,7 +814,7 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_ti
a = builder_->CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
  b = builder_->CreateFPCast(b, c_ty);
-res = builder_->CreateCall(f_mul_add, {a, b, res});
+res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
}
set_value(dot, idx, res);
});
@@ -790,7 +834,7 @@ void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distrib
a = builder_->CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
  b = builder_->CreateFPCast(b, c_ty);
-res = builder_->CreateCall(f_mul_add, {a, b, res});
+res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
set_value(dot, idx, res);
});
}
@@ -805,7 +849,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {

distributed_tile *TD = (distributed_tile*)tmap_.at(D);
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_);
-Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
+Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
@@ -835,8 +879,8 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
for_each(sqt, [&](indices_t idx){
Value *val = get_value(sqt->get_operand(0), idx);
Module* module = builder_->GetInsertBlock()->getModule();
-Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, {val->getType()});
-Value *ret = builder_->CreateCall(sqrt, {val});
+Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, std::vector<llvm::Type*>{val->getType()});
+Value *ret = builder_->CreateCall(sqrt, std::vector<llvm::Value*>{val});
set_value(sqt, idx, ret);
});
}
@@ -849,7 +893,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
unsigned axis = x->get_axis();

Type *fp32_ty = builder_->getFloatTy();
-FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false);
+FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty}, false);
InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false);
InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false);

@@ -871,8 +915,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
}
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
-case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y});
-case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y});
+case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, std::vector<llvm::Value*>{x, y});
+case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, std::vector<llvm::Value*>{x, y});
default: assert(false); return nullptr;
}
};
@@ -910,11 +954,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
thread_acc = accumulate(thread_acc, current);
});
// reduce within warp
-FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false);
+FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), std::vector<llvm::Type*>{thread_acc->getType(), builder_->getInt32Ty()}, false);
InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
Value *warp_acc = thread_acc;
for(int i = 16; i > 0; i >>= 1)
-warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)}));
+warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{warp_acc, builder_->getInt32(i)}));
// shared memory pointer
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
Type *res_ty = arg_tile->get_ty();
@@ -935,7 +979,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
builder_->SetInsertPoint(bb_final_acc);
Value* final_val = builder_->CreateLoad(load_ptr);
for(int i = (num_warps_+1)/2; i > 0; i >>= 1)
-final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)}));
+final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{final_val, builder_->getInt32(i)}));
builder_->CreateStore(final_val, load_ptr);
builder_->CreateBr(bb_final_acc_done);
// // store first warp done
@@ -109,14 +109,18 @@ void coalesce::run(ir::module &mod) {
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
-if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
+if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){
  extract_ld(i, axes);
+}
}
// update list of values to rematerialize
if(axes.empty())
  continue;
-for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
+for(auto it = ++axes.rbegin(); it != axes.rend(); it++){
+  if(it->second.size() == 1)
+    continue;
  remat.insert(remat.begin(), it->second.begin(), it->second.end());
+}
}
// rematerialize values
for(ir::io_inst *r: remat) {
@@ -34,10 +34,13 @@ void disassociate::run(ir::module &mod) {
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
+ir::value* op = i->get_operand(0);
+if(!dynamic_cast<ir::user*>(op))
+  return;
+if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
+  return;
std::map<int, std::set<ir::user*>> chains;
std::set<ir::value*> seen;
-if(!dynamic_cast<ir::user*>(i->get_operand(0)))
-  return;
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
  clone_info[i] = chains;
@@ -81,20 +81,21 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
add->replace_all_uses_with(new_dot);
return true;
}
return false;
}

-bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
-auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
-if(cfs) {
-ir::value *arg = cfs->get_operand(0);
-ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
-if(!cts)
-  return false;
-cfs->replace_all_uses_with(cts->get_operand(0));
-return true;
-}
-}
+//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
+// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
+// if(cfs) {
+// ir::value *arg = cfs->get_operand(0);
+// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
+// if(!cts)
+//   return false;
+// cfs->replace_all_uses_with(cts->get_operand(0));
+// return true;
+// }
+//}

bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
@@ -196,7 +197,7 @@ void peephole::run(ir::module &mod) {
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
-was_modified = was_modified || rewrite_cts_cfs(i, builder);
+// was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
@@ -146,7 +146,8 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
-llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
+std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
+llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
llvm::Function* fn = src->getFunction("matmul");
llvm::FunctionType *fn_ty = fn->getFunctionType();
@@ -250,11 +251,10 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
int version;
dispatch::cuDriverGetVersion(&version);
int major = version / 1000;
-int minor = (version - major*1000) / 10;
+// int minor = (version - major*1000) / 10;
if(major < 10)
  throw std::runtime_error("Triton requires CUDA 10+");
-if(minor >= 1)
-  find_and_replace(result, ".version", "\n", ".version 6.4\n");
+find_and_replace(result, ".version", "\n", ".version 6.4\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
return result;
@@ -265,6 +265,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll

cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx(*context);
+// std::cout << source << std::endl;
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
@@ -79,7 +79,7 @@ void host_stream::synchronize() {
|
||||
|
||||
}
|
||||
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
|
||||
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
|
||||
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
|
||||
void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
|
||||
@@ -112,7 +112,7 @@ void cl_stream::synchronize() {
   check(dispatch::clFinish(*cl_));
 }

-void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
+void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
   std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
   check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
 }
@@ -149,12 +149,11 @@ void cu_stream::synchronize() {
   dispatch::cuStreamSynchronize(*cu_);
 }

-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event) {
-  driver::cu_kernel* cu_kernel = (driver::cu_kernel*)kernel;
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** extra) {
   cu_context::context_switcher ctx_switch(*ctx_);
   if(event)
     dispatch::cuEventRecord(event->cu()->first, *cu_);
-  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, (void**)cu_kernel->cu_params(), NULL);
+  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra);
   if(event)
     dispatch::cuEventRecord(event->cu()->second, *cu_);
 }
@@ -70,9 +70,10 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
     return phi;
   // unique value or self-reference
   ir::value *same = *non_self_ref.begin();
-  std::set<ir::user*> users = phi->get_users();
+  assert(same != nullptr);
   phi->replace_all_uses_with(same);
   phi->erase_from_parent();
+  std::set<ir::user*> users = phi->get_users();
   for(ir::user* u: users)
     if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
       if(uphi != phi)
@@ -1,4 +1,5 @@
+#include <cassert>
 #include <iostream>
 #include "triton/ir/value.h"
 #include "triton/ir/instructions.h"

@@ -19,8 +20,11 @@ void value::add_use(user *arg) {
   users_.insert(arg);
 }

-unsigned value::erase_use(user *arg){
-  return users_.erase(arg);
+value::users_t::iterator value::erase_use(user *arg){
+  auto it = users_.find(arg);
+  if(it == users_.end())
+    return it;
+  return users_.erase(it);
 }

 // TODO: automatic naming scheme + update symbol table
@@ -29,9 +33,12 @@ void value::set_name(const std::string &name){
 }

 void value::replace_all_uses_with(value *target){
-  throw std::runtime_error("not implemented");
+  for (auto it = users_.begin(); it != users_.end(); ) {
+    it = (*it)->replace_uses_of_with(this, target);
+  }
 }


 void visitor::visit_value(ir::value* v) {
   v->accept(this);
 }
@@ -59,18 +66,12 @@ unsigned user::get_num_hidden() const {
   return num_hidden_;
 }

-void user::replace_all_uses_with(value *target) {
-  for(auto it = users_.begin(); it != users_.end(); it++){
-    (*it)->replace_uses_of_with(this, target);
-  }
-}
-
-void user::replace_uses_of_with(value *before, value *after) {
+value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
   for(size_t i = 0; i < ops_.size(); i++)
     if(ops_[i] == before)
      ops_[i] = after;
   after->add_use(this);
-  before->erase_use(this);
+  return before->erase_use(this);
 }

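The return-type change above exists because `replace_uses_of_with` erases the current user from `users_` while `replace_all_uses_with` is still iterating over that container, so each erase must hand back the successor iterator. The same hazard is easy to demonstrate in Python; a minimal sketch, not Triton code:

# Deleting from a container while iterating over it is the bug the C++
# change guards against; Python raises rather than silently corrupting the loop.
users = {"phi1", "phi2", "add3"}
try:
    for u in users:
        users.discard(u)   # RuntimeError: set changed size during iteration
except RuntimeError as e:
    print(e)

# The fix mirrors erase_use returning the next position: drive the loop
# from something the removals cannot invalidate (here, a snapshot).
for u in list(users):
    users.discard(u)
print(users)  # set()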
@@ -749,10 +749,21 @@ void UnaryOp::CastOpTypeChecking() {
   if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
     Error(this, "tile with more than one element cannot be casted to scalar");
   if(type_->IsTile() && operandType->IsTile()){
-    auto shape = type_->ToTile()->Shape();
     auto operandShape = operandType->ToTile()->Shape();
-    if(operandShape.size() > shape.size())
-      Error(this, "cast cannot reduce operand rank");
+    auto shape = type_->ToTile()->Shape();
+    // this is a shape downcast
+    if(operandShape.size() > shape.size()){
+      size_t operandNumel = 1;
+      size_t numel = 1;
+      for(auto x: operandShape)
+        operandNumel *= x;
+      for(auto x: shape)
+        numel *= x;
+      if(operandNumel != numel)
+        Error(this, "cast cannot change number of elements");
+      return;
+    }
+    // this is a shape upcast
     while(operandShape.size() < shape.size())
       operandShape.insert(operandShape.begin(), 1);
     for(size_t i = 0; i < shape.size(); i++) {
@@ -277,6 +277,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
     ir::value* val = ret_;
     return set_ret(bld_->create_atomic_exch(ptr, val));
   }
+  if(name == "f32_atomic_add"){
+    VisitExpr(funcCall->Args()->at(0));
+    ir::value* ptr = ret_;
+    VisitExpr(funcCall->Args()->at(1));
+    ir::value* val = ret_;
+    return set_ret(bld_->create_atomic_add(ptr, val));
+  }
   if(name == "sqrtf"){
     VisitExpr(funcCall->Args()->at(0));
     ir::value* ret = ret_;
@@ -623,7 +630,10 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
     return ir::attribute(ir::readonly);
   if(attr.kind == ASTNode::Attr::WRITEONLY)
     return ir::attribute(ir::writeonly);
+  if(attr.kind == ASTNode::Attr::RETUNE)
+    return ir::attribute(ir::retune);
   error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
+  return ir::attribute(ir::not_implemented);
 }

 void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {
@@ -2778,6 +2778,8 @@ ASTNode::Attr Parser::ParseAttribute() {
     ret.kind = ASTNode::Attr::MULTIPLEOF;
   else if(name == "noalias")
     ret.kind = ASTNode::Attr::NOALIAS;
+  else if(name == "retune")
+    ret.kind = ASTNode::Attr::RETUNE;
   else
     Error(tok, "unknown attribute kind");
   // set exprs
@@ -151,27 +151,23 @@ function::caller::caller(ir::function *ir,
   bin_.reset(driver::kernel::create(&*parent, name_.c_str()));
   // extract signature
   ir::function_type* ty = ir->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++)
+  for(size_t i = 0; i < ty->get_num_params(); i++){
     param_tys_.push_back(convert(ty->get_param_ty(i)));
+    if(!ir->has_attr(i+1))
+      continue;
+    for(ir::attribute attr: ir->attrs().at(i + 1))
+      if(attr.get_kind() == ir::retune)
+        retune_.push_back(i);
+  }
 }


-void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, const std::vector<arg>& args) const {
-  if(args.size() != param_tys_.size())
-    throw std::runtime_error("invalid number of arguments");
-  // set arguments
-  for(size_t i = 0; i < args.size(); i++){
-    arg arg_i = args.at(i);
-    arg_type ty = arg_i.type();
-    if(ty != param_tys_.at(i))
-      throw std::runtime_error("invalid type for argument " + std::to_string(i));
-    if(ty == BUFFER_T){
-      driver::buffer* buf = *((driver::buffer**)arg_i.data());
-      bin_->setArg(i, buf->size() == 0 ? nullptr : buf);
-    }
-    else
-      bin_->setArg(i, size_of(ty), arg_i.data());
-  }
+void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const {
+  void *config[] = {
+    CU_LAUNCH_PARAM_BUFFER_POINTER, args,
+    CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
+    CU_LAUNCH_PARAM_END
+  };
   // set grid
   if(_grid.size() > 3)
     throw std::runtime_error("grid size must be no greater than 3");
@@ -179,7 +175,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1});
+  stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, config);
 }

@@ -218,7 +214,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
   codegen::transform::cts cts;
   codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
   // run passes
-  // ir::print(module, std::cout);
   dce.run(module);
   disassociate.run(module);
   dce.run(module);
@@ -252,20 +247,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,

 // create Binary from options
 function::caller* function::make(driver::stream *stream, options_t opt) {
-  // cache path
-  std::string cache_path = cache_path_ + opt.to_str() + ".ptx";
-  int ref_mtime = tools::mtime(cache_ref_);
-  int ptx_mtime = tools::mtime(cache_path);
-  // if cached ptx is newer than reference library
-  if(!ref_mtime || !ptx_mtime || ref_mtime < ptx_mtime){
-    std::ifstream ifs(cache_path);
-    // file is empty -- invalid
-    if(ifs && ifs.peek() == std::ifstream::traits_type::eof())
-      return nullptr;
-    // load cached caller
-    if(ifs)
-      return new caller(stream->context(), ifs, opt);
-  }
   // pre-process
   TokenSequence tokens;
   Preprocessor cpp(&src_, true);
@@ -282,18 +263,11 @@ function::caller* function::make(driver::stream *stream, options_t opt) {
|
||||
try{
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
}catch(const std::runtime_error&){
|
||||
if(!cache_path_.empty())
|
||||
std::ofstream ofs(cache_path);
|
||||
return nullptr;
|
||||
}
|
||||
// create callable
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
caller* ret = new caller(tmp, std::move(bin), opt);
|
||||
// serialize callable
|
||||
if(!cache_path_.empty()){
|
||||
std::ofstream ofs(cache_path);
|
||||
ret->write(ofs);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -331,49 +305,25 @@ void function::precompile(driver::stream* stream,
   throw std::runtime_error("could not compile kernel");
 }

-// return auto-tuning key for given function arguments
-function::cache_key_t function::get_key(driver::stream *stream, const std::vector<arg>& args) {
-  cache_key_t ret;
-  ret.first = stream->context()->device();
-  for(size_t i = 0; i < args.size(); i++){
-    arg_type ty = args.at(i).type();
-    if(!is_int_type(ty))
-      continue;
-    long val = 0;
-    std::memcpy((void*)&val, args.at(i).data(), size_of(ty));
-    ret.second.push_back(val);
-  }
-  return ret;
-}
 // returns program with best compilation options for given parameter
 function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
-                                     const std::vector<arg>& args) {
-  // // copy buffer argument so that auto-tuning doesn't corrupt data
-  // std::list<std::shared_ptr<driver::cu_buffer>> copies;
-  // std::vector<arg> args = args;
-  // for(arg& x: args)
-  //   if(x.type() == BUFFER_T){
-  //     driver::buffer* old = x.buffer();
-  //     driver::context* ctx = old->context();
-  //     size_t size = old->size();
-  //     copies.push_back(std::make_shared<driver::cu_buffer>(ctx, size));
-  //     x = arg(copies.back().get());
-  //   }
+                                     void** args, size_t args_size) {
   // fast path -- no autotuning necessary
   if(callers_.size() == 1)
     return &*callers_.begin()->second;
   // slow path -- autotuning necessary
+  // TODO: copy buffer arguments so that auto-tuning doesn't corrupt data
   double best_ts = INFINITY;
   caller* ret = nullptr;
   for(auto &x : callers_){
     if(x.second == nullptr)
       throw std::runtime_error("configuration not compiled");
     caller* current = &*x.second;
-    double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args); },
+    double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size); },
                              stream, true);
     ret = (ts < best_ts) ? current : ret;
     best_ts = std::min(ts, best_ts);
   }
   stream->synchronize();
   return ret;
 }

@@ -394,12 +344,14 @@ std::string function::preheader() {
 #define __noalias __attribute__((noalias))
 #define __aligned(A) __attribute__((aligned(A)))
 #define __multipleof(A) __attribute__((multipleof(A)))
+#define __retune __attribute__((retune))

 #define F32_INFINITY bitcast<float>(0x7F800000)
 #define F16_INFINITY bitcast<half>((int16)0x7C00)

 extern int atomic_cas(int*, int, int);
 extern int atomic_xchg(int*, int);
+extern float f32_atomic_add(float*, float);
 extern int get_program_id(int);
 extern int get_num_programs(int);
 extern float sqrtf(float);
@@ -452,27 +404,35 @@ function::function(const std::string &src,
   src_ = preheader() + src_;
 }

-void function::operator()(const std::vector<arg>& args,
-                          const grid_fn_ty& grid_fn,
-                          driver::stream *stream) {
+void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
   // pre-compile kernels
-  if(callers_.empty())
+  if(callers_.empty()){
     precompile(stream, opt_);
+    size_t cumsum = 0;
+    for(arg_type ty: callers_.begin()->second->param_tys()){
+      args_off_.push_back(cumsum);
+      cumsum += size_of(ty);
+    }
+  }
+  // re-tuning key
+  cache_key_t key;
+  key.first = stream->context()->device();
+  key.second = callers_.begin()->second->retune();
   // auto-tune if necessary
-  auto key = get_key(stream, args);
   auto it = cache_.find(key);
   if(it == cache_.end()){
-    auto best = autotune(stream, grid_fn, args);
+    auto best = autotune(stream, grid_fn, args, args_size);
    it = cache_.insert({key, best}).first;
   }
   // run
-  (*it->second)(stream, grid_fn(it->second->opt()), args);
+  (*it->second)(stream, grid_fn(it->second->opt()), args, args_size);
 }

-void function::operator()(const std::vector<arg>& args,
+void function::operator()(void** args,
+                          size_t args_size,
                           const grid_t& grid,
                           driver::stream *stream) {
-  return this->operator()(args, [&grid](const options_t&){ return grid; }, stream);
+  return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
 }

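In short, the new cache key replaces per-call inspection of every integer argument (the deleted `get_key`) with the values of the arguments explicitly marked `__retune`. A hedged Python sketch of that policy, with illustrative names that are not the Triton API:

# Sketch: re-autotune only when the device or a __retune argument changes.
cache = {}

def run(device, args, retune_indices, candidates, bench):
    """candidates: pre-compiled kernels; bench: a timing callable (assumed)."""
    key = (device, tuple(args[i] for i in retune_indices))
    if key not in cache:
        # slow path, as in function::autotune: keep the fastest candidate
        cache[key] = min(candidates, key=bench)
    return cache[key]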
python/examples/tutorials/conv2d.py (new file, 202 lines)
@@ -0,0 +1,202 @@
import torch
import triton

class _conv(torch.autograd.Function):
    src = """
    __global__ void conv(TYPE *A __noalias __readonly __aligned(16),
                         TYPE *B __noalias __readonly __aligned(16),
                         TYPE *C __noalias __aligned(16),
                         float alpha,
                         // equivalent matmul
                         int M __retune,
                         int N __retune,
                         int K __retune,
                         // convolution properties
                         int pad_h, int pad_w, int stride_h, int stride_w,
                         // pointer increment
                         int *ADELTA,
                         // memory strides
                         int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
                         int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
                         int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
      // prologue
      int ridx = get_program_id(0);
      int ridy = get_program_id(1);
      int ridz = get_program_id(2);
      int gridx = M / TM;
      int gridy = N / TN;
      int rid = ridx + ridy * gridx;
      ridx = rid / gridy;
      ridy = rid % gridy;
      int rm[TM] = ridx * TM + 0 ... TM;
      int rn[TN] = ridy * TN + 0 ... TN;
      // reduction splitting
      K          = K / TZ;
      int rk[TK] = ridz * K + 0 ... TK;

      // unpack aggregate rows
      // m = (z, p, q)
      int rq[TM]  = rm % QQ;
      int rzp[TM] = rm / QQ;
      int rp[TM]  = rzp % PP;
      int rz[TM]  = rzp / PP;
      // unpack aggregate reduction
      // k = (ci, r, s)
      int rs [TK]  = rk % SS;
      int rcir[TK] = rk / SS;
      int rr [TK]  = rcir % RR;
      int rci [TK] = rcir / RR;

      // padding / striding
      int rh_0[TM] = rp * stride_h - pad_h;
      int rw_0[TM] = rq * stride_w - pad_w;
      int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
      int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];

      // pointers to lhs
      int offa[TM, TK] = rz [:, newaxis] * lda_z  +
                         rci[newaxis, :] * lda_ci +
                         rh              * lda_h  +
                         rw              * 1;
      TYPE* pa[TM, TK] = A + offa;
      int* padelta[TK] = ADELTA + rk;
      // pointers to rhs
      int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
                         rr [:, newaxis] * ldb_r  +
                         rs [:, newaxis] * ldb_s  +
                         rn [newaxis, :] * 1;
      TYPE* pb[TK, TN] = B + offb;

      // prefetches operands
      bool checkam[TM, TK] = rm[:, newaxis] < M;
      bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
      bool checkb[TK, TN] = rk[:, newaxis] < K;
      TYPE a[TM, TK] = checka ? *pa : 0;
      TYPE b[TK, TN] = checkb ? *pb : 0;
      int total = 0;

      // reduction loop
      float acc[TM, TN] = 0;
      for(int k = K; k > 0; k -= TK){
        acc += a @ b;
        // increment A
        int adelta[TK] = *padelta;
        padelta += TK;
        pa += adelta[newaxis, :];
        // bounds-checking A
        rk += TK;
        rs = rk % SS;
        rcir = rk / SS;
        rr = rcir % RR;
        rh = rh_0[:, newaxis] + rr[newaxis, :];
        rw = rw_0[:, newaxis] + rs[newaxis, :];
        bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
        // increment B
        pb += TK * ldb_s;
        // bounds-checking B
        bool checkb[TK, TN] = k > TK;
        a = checka ? *pa : 0;
        b = *?(checkb)pb;
      }
      acc = acc * alpha;
      TYPE c[TM, TN] = acc;

      // epilogue
      rm  = ridx * TM + 0 ... TM;
      rn  = ridy * TN + 0 ... TN;
      rq  = rm % QQ;
      rzp = rm / QQ;
      rp  = rzp % PP;
      rz  = rzp / PP;
      int offc[TM, TN] = rz [:, newaxis] * ldc_z  +
                         rn [newaxis, :] * ldc_co +
                         rp [:, newaxis] * ldc_p  +
                         rq [:, newaxis] * 1;
      TYPE* pc[TM, TN] = C + offc;
      bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;

#if (TZ==1)
      *?(checkc) pc = c;
#else
      // accumulate partial result using spin-locks
      int *plock  = locks + rid;
      int *pcount = plock + get_num_programs(0) * get_num_programs(1);
      for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
      int count = *pcount;
      if(count == 0)
        *?(checkc) pc = c;
      else
        *?(checkc) pc = c + *?(checkc)pc;
      atomic_xchg(pcount, (count + 1) % TZ);
      atomic_xchg(plock, 0);
#endif
    }
    """

    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        s  = IDX % S
        cr = IDX // S
        r  = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride, time):
        # create kernel if necessary
        dtype = a.dtype
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2*pad[0] - R)//stride[0] + 1
        Q = (W + 2*pad[1] - S)//stride[1] + 1
        # compile kernel
        if dtype not in _conv.kernel:
            TK = 8
            defines = {
                'TYPE' : dtype,
                'TM'   : [16, 32, 64, 128],
                'TN'   : [16, 32, 64, 128],
                'TK'   : [TK],
                'TZ'   : [1],
                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
            }
            idx = torch.arange(CI*R*S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
            delta = delta.type(torch.int32).cuda()
            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
        delta, kernel = _conv.kernel[dtype]
        # allocate output
        c = triton.empty([Z, CO, P, Q], dtype=dtype)
        # enqueue
        grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
                            triton.cdiv(CO, opt.d('TN'))]
        time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
                         pad[0], pad[1], stride[0], stride[1],
                         delta,
                         a.stride(0), a.stride(1), a.stride(2), a.stride(3),
                         b.stride(0), b.stride(1), b.stride(2), b.stride(3),
                         c.stride(0), c.stride(1), c.stride(2), c.stride(3),
                         grid=grid, bench=100)
        return c


conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
b = torch.rand((CI, R, S, CO)).cuda()
time = [None]
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
c = conv(a, b, pad, stride, time)
print((cc - c).abs().max() / max(cc.max(), c.max()))
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
#zc = torch.matmul(a,b)
#zc_ = dot(a,b)
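A note on the `ADELTA` table built in `forward` above: inside the reduction loop the kernel advances `pa` by a precomputed per-`k` increment instead of re-deriving the `(ci, r, s)` offsets from scratch. A small standalone check of that identity, with toy sizes and assumed strides:

import torch

def unpack(idx, CI, R, S):
    # same packing as _conv.unpack: idx = (ci * R + r) * S + s
    s = idx % S
    cr = idx // S
    r = cr % R
    ci = cr // R
    return ci, r, s

CI, R, S, TK = 4, 3, 3, 8
sci, sh, sw = 64, 8, 1  # toy strides for the (ci, h, w) axes of A
idx = torch.arange(CI * R * S)
ci, r, s = unpack(idx, CI, R, S)
nci, nr, ns = unpack(idx + TK, CI, R, S)
delta = (nci - ci) * sci + (nr - r) * sh + (ns - s) * sw

# advancing a pointer by delta[k] equals recomputing the (ci, r, s)
# offsets after rk += TK, which the kernel's reduction loop relies on
off = ci * sci + r * sh + s * sw
noff = nci * sci + nr * sh + ns * sw
assert torch.equal(off + delta, noff)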
python/examples/tutorials/mat_copy.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import torch
import triton

class _copy(torch.autograd.Function):
    src = """
    __global__ void copy(TYPE * X, TYPE * Y,
                         int M __retune,
                         int N __retune,
                         int ldx __multipleof(8)) {
      // extract program ID
      int pidm = get_program_id(0); //(1)
      int pidn = get_program_id(1); //(2)

      // create 1D ranges along the matrix's two axes
      int rm[TM] = pidm * TM + 0 ... TM; //(3)
      int rn[TN] = pidn * TN + 0 ... TN; //(4)

      // create 2D array of pointers
      TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
      TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)

      *py = *px;
    }
    """

    kernel = None  ### initialize later when we know the sizes

    @staticmethod
    def forward(ctx, x):

        M, N = x.shape

        ldx = N

        dtype = x.dtype

        y = torch.empty((M, N)).cuda()

        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],
            'TN': [32, 64, 128],
        }
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]

        if _copy.kernel is None:
            _copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])

        _copy.kernel(x, y, M, N, ldx, grid=grid)

        return y

copy = _copy.apply

# test
torch.manual_seed(0)
x = torch.randn(8, 4).cuda()

print(x)

ya = x
yb = copy(x)

print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))

print(ya == yb)
python/examples/tutorials/mat_mul.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import torch
import triton

class _dot(torch.autograd.Function):
    src = """
    __global__ void dot(TYPE *A __noalias __readonly __aligned(16),
                        TYPE *B __noalias __readonly __aligned(16),
                        TYPE *C __noalias __aligned(16),
                        float alpha,
                        int M __retune,
                        int N __retune,
                        int K __retune,
                        int lda __multipleof(8),
                        int ldb __multipleof(8),
                        int ldc __multipleof(8)) {
      // prologue
      int ridx = get_program_id(0);
      int ridy = get_program_id(1);
      int ridz = get_program_id(2);
      int gridx = M / TM;
      int gridy = N / TN;
      int rid = ridx + ridy * gridx;
      ridx = rid / gridy;
      ridy = rid % gridy;
      int rm[TM] = ridx * TM + 0 ... TM;
      int rn[TN] = ridy * TN + 0 ... TN;

      // reduction splitting
      K          = K / TZ;
      int rk[TK] = ridz * K + 0 ... TK;

      // pointers to operands
      int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
      int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
      TYPE* pa[TM, TK] = A + offa;
      TYPE* pb[TK, TN] = B + offb;

      // prefetches operands
      bool checka[TM, TK] = rk[newaxis, :] < K;
      bool checkb[TK, TN] = rk[:, newaxis] < K;
      TYPE a[TM, TK] = checka ? *pa : 0;
      TYPE b[TK, TN] = checkb ? *pb : 0;

      // reduction loop
      float acc[TM, TN] = 0;
      for(int k = K; k > 0; k -= TK){
        acc += a @ b;
        bool checka[TM, TK] = k > TK;
        bool checkb[TK, TN] = k > TK;
        pa += TK * STRIDE_AK;
        pb += TK * STRIDE_BK;
        a = *?(checka)pa;
        b = *?(checkb)pb;
      }
      acc = acc * alpha;
      TYPE c[TM, TN] = acc;

      // epilogue
      int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
      int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
      int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
      TYPE* pc[TM, TN] = C + offc;
      bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);

#if (TZ==1)
      *?(checkc) pc = c;
#else
      // accumulate partial result using spin-locks
      int *plock  = locks + rid;
      int *pcount = plock + get_num_programs(0) * get_num_programs(1);
      for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
      int count = *pcount;
      if(count == 0)
        *?(checkc) pc = c;
      else
        *?(checkc) pc = c + *?(checkc)pc;
      atomic_xchg(pcount, (count + 1) % TZ);
      atomic_xchg(plock, 0);
#endif
    }
    """

    @staticmethod
    def forward(ctx, a, b):
        c = _dot._call(a, b)
        return c


    kernel = dict()

    @staticmethod
    def _call(a, b):
        # create kernel if necessary
        dtype = a.dtype
        if dtype not in _dot.kernel:
            defines = {
                'TYPE' : dtype,
                'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
                'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
                'TM'   : [64, 128],
                'TN'   : [64, 128],
                'TK'   : [8, 16],
                'TZ'   : [1]
            }
            _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
        kernel = _dot.kernel[dtype]
        # allocate output
        M, K = a.shape
        K, N = b.shape
        c = triton.empty([M, N], dtype=dtype)
        # enqueue
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
                            triton.cdiv(N, opt.d('TN'))]
        time = kernel(a, b, c, 1., M, N, K,
                      a.stride(0), b.stride(0), c.stride(0),
                      grid=grid, bench=100)
        print(2*M*N*K/(time*1e-6)*1e-9)
        return c


dot = _dot.apply

torch.manual_seed(0)

M, N, K = 2048, 2048, 2048
a = torch.rand((M, K)).cuda()
b = torch.rand((K, N)).cuda()


#zc = torch.matmul(a,b)
zc_ = dot(a,b)

#print(torch.allclose(zc, zc_))
python/examples/tutorials/mat_transpose.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import torch
import triton

class _transpose(torch.autograd.Function):
    src = """
    __global__ void transpose(TYPE * X, TYPE * Y,
                              int M __retune,
                              int N __retune,
                              int ldx __multipleof(8), int ldy __multipleof(8)) {
      // extract program ID
      int pidm = get_program_id(0); //(1)
      int pidn = get_program_id(1); //(2)

      // create 1D ranges along the matrix's two axes
      int rm[TM] = pidm * TM + 0 ... TM; //(3)
      int rn[TN] = pidn * TN + 0 ... TN; //(4)

      // create 2D array of pointers
      TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
      TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)

      // create bounds-checking mask
      bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
      bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)

      // conditional write-back using the conditional dereferencing operator '*?()'
      *?(checky)py = ^(*?(checkx)px); //(7)
    }
    """

    kernel = None  ### initialize later when we know the sizes

    @staticmethod
    def forward(ctx, x):

        M, N = x.shape

        ldx = N
        ldy = M

        dtype = x.dtype

        y = torch.empty((N, M)).cuda()

        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],
            'TN': [32, 64, 128],
        }

        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]

        if _transpose.kernel is None:
            _transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])

        _transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)

        return y

transpose = _transpose.apply

# test
torch.manual_seed(0)
x = torch.randn(1024, 128).cuda()

print(x)

ya = torch.t(x)
yb = transpose(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))

print(ya == yb)
python/examples/tutorials/vec_add.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import torch
import triton

class _add(torch.autograd.Function):
    src = """
    __global__ void add(float* z, float* x, float* y, int N) {

      int pid = get_program_id(0);

      int offset[TILE] = pid * TILE + 0 ... TILE;
      float* pz[TILE] = z + offset;
      float* px[TILE] = x + offset;
      float* py[TILE] = y + offset;

      bool check[TILE] = offset < N;

      *?(check)pz = *?(check)px + *?(check)py;
    }
    """

    kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

    @staticmethod
    def forward(ctx, x, y):
        z = torch.empty_like(x).cuda()

        N = x.numel()
        grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)

        _add.kernel(z, x, y, N, grid=grid)

        return z

add = _add.apply

# test
torch.manual_seed(0)
x = torch.rand(98432).cuda()
y = torch.rand(98432).cuda()
za = x + y
zb = add(x, y)

print(torch.allclose(za, zb))
@@ -8,9 +8,11 @@ import distutils
 import glob
 from distutils.version import LooseVersion
 from setuptools import setup, Extension, find_packages
+from torch.utils.cpp_extension import include_paths, library_paths
 from setuptools.command.build_ext import build_ext
 from setuptools.command.test import test as TestCommand
 import distutils.spawn
+import torch


 def find_llvm():
@@ -58,12 +60,17 @@ class CMakeBuild(build_ext):
         # python directories
         python_include_dirs = distutils.sysconfig.get_python_inc()
         python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
+        torch_include_dirs = include_paths(True)
+        torch_library_dirs = library_paths(True)
+        abi = torch._C._GLIBCXX_USE_CXX11_ABI
         cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                       '-DBUILD_TESTS=OFF',
                       '-DBUILD_PYTHON_MODULE=ON',
                       #'-DPYTHON_EXECUTABLE=' + sys.executable,
                       #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
-                      '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
+                      '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
+                      '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
+                      '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
                       '-DLLVM_CONFIG=' + find_llvm()]
         # configuration
         cfg = 'Debug' if self.debug else 'Release'
@@ -80,8 +87,6 @@ class CMakeBuild(build_ext):
             build_args += ['--', '-j4']

         env = os.environ.copy()
-        env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
-                                                              self.distribution.get_version())
         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
         sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
@@ -100,7 +105,7 @@ for d in directories:

 setup(
     name='triton',
-    version='0.1',
+    version='0.2.3',
     author='Philippe Tillet',
     author_email='ptillet@g.harvard.edu',
     description='A language and compiler for custom Deep Learning operations',
@@ -3,57 +3,51 @@
 #include <pybind11/stl.h>
 #include <pybind11/functional.h>
 #include <string>
 #include <regex>
 #include <algorithm>
 #include "triton/runtime/function.h"
 #include "triton/runtime/arg.h"
 #include "triton/lang/code_gen.h"
 #include "triton/lang/parser.h"
 #include "triton/lang/cpp.h"
 #include "triton/driver/device.h"
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
 #include "triton/driver/module.h"
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/tools/bench.hpp"

 using namespace triton;

 namespace rt = triton::runtime;

-std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
+typedef std::pair<size_t, size_t> map_key_t;
+std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
 std::map<size_t, double> fp64scalar_map;
 std::map<size_t, int64_t> i64scalar_map;

 /* Grid map */

-void register_grid(size_t id,
+void register_grid(const map_key_t& key,
                    const rt::function::grid_fn_ty& grid_fn) {
-  id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
+  id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
 }

-void delete_grid(size_t id) {
-  id_grid_map.erase(id);
+void delete_grid(const map_key_t& key) {
+  id_grid_map.erase(key);
 }

 /* Function map */

-void register_fn(size_t id,
+void register_fn(const map_key_t& key,
                  const std::string& src,
                  const rt::function::options_space_t& opt,
                  const std::string &cache_ref) {
-  id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
+  id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
 }

-void delete_fn(size_t id) {
-  id_fn_map.erase(id);
+void delete_fn(const map_key_t& key) {
+  id_fn_map.erase(key);
 }

-void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
+void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
   pybind11::buffer_info info = data.request();
-  id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
+  id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
 }

 void cleanup() {
@@ -82,196 +76,6 @@ int64_t retrieve_scalar(size_t id) {
   return i64scalar_map.at(id);
 }

-/* TF source-code generation */
-
-inline std::string to_tf_ty(ir::type *ty) {
-  if(ty->is_integer_ty(1))
-    return "bool";
-  if(ty->is_integer_ty(8))
-    return "int8";
-  if(ty->is_integer_ty(16))
-    return "int16";
-  if(ty->is_integer_ty(32))
-    return "int32";
-  if(ty->is_integer_ty(64))
-    return "int64";
-  if(ty->is_half_ty())
-    return "float16";
-  if(ty->is_float_ty())
-    return "float";
-  if(ty->is_double_ty())
-    return "double";
-  if(ty->is_pointer_ty())
-    return "Tensor";
-  throw std::runtime_error("unknown type");
-}
-
-inline std::string to_tf_scalar_ty(ir::type *ty) {
-  if(ty->is_pointer_ty())
-    return to_tf_ty(ty->get_pointer_element_ty());
-  else {
-    return to_tf_ty(ty);
-  }
-}
-
-inline std::string ref_to_tf_ty(ir::type *ty) {
-  std::string res = to_tf_ty(ty);
-  if(ty->is_pointer_ty())
-    res = "const " + res + "&";
-  return res;
-}
-
-std::string tf_normalize(const std::string& name) {
-  std::string ret = name;
-  auto tolower = [](char c) { return std::tolower(c);};
-  std::transform(ret.begin(), ret.end(), ret.begin(), tolower);
-  return ret;
-}
-
-struct tf_alloc_t{
-  enum type_t{
-    OUTPUT,
-    TEMP
-  };
-
-  tf_alloc_t(const std::string& _name, type_t _type)
-    : name(_name), type(_type), tf_name(tf_normalize(_name)){ }
-
-  std::string tf_name;
-  std::string name;
-  type_t type;
-  size_t shape_id;
-};
-
-typedef std::vector<tf_alloc_t> alloc_map_t;
-
-
-void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
-  for(unsigned i = 0; i < args.size(); i++){
-    ir::value *arg = args[i];
-    const std::string& name = arg->get_name();
-    std::string ty = to_tf_ty(arg->get_type());
-    if(!arg->get_type()->is_pointer_ty())
-      os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
-    else if(std::find_if(allocs.begin(), allocs.end(),
-                         [&](tf_alloc_t x) {
-                           return x.name == name;
-                         }) == allocs.end())
-      os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
-    else
-      os << " Tensor* " << name << " = nullptr;\n ";
-  }
-}
-
-void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
-  // initialize shapes
-  for(const auto& x: allocs)
-    os << " TensorShape " << x.name << "_shape;\n ";
-  for(const auto& x: allocs)
-    os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n ";
-  for(const auto& x: allocs)
-    os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n ";
-  for(const auto& x: allocs)
-    os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n ";
-  for(const auto& x: allocs)
-    os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) "
-       << x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n ";
-
-  // allocate
-  int output = 0;
-  for(const auto& x: allocs){
-    if(x.type == tf_alloc_t::OUTPUT)
-      os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n ";
-    else
-      os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n ";
-  }
-}
-
-void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
-  for(unsigned i = 0; i < args.size(); i++){
-    ir::argument *arg = args[i];
-    if(!arg->get_type()->is_pointer_ty())
-      continue;
-    const std::string& name = arg->get_name();
-    os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->tensor_data().size(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n ";
-  }
-}
-
-void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
-  os << " std::function<void()> run = [&](){\n ";
-  os << "   (*id_fn_map.at(id_))({";
-  for(unsigned i = 0; i < args.size() ; i++){
-    ir::argument *arg = args[i];
-    std::string name = arg->get_name();
-    if(arg->get_type()->is_pointer_ty())
-      name = "&cu_" + name;
-    if(i > 0)
-      os << ", ";
-    os << name;
-  }
-  os << "}, *id_grid_map.at(id_), stream);\n ";
-  os << " };\n ";
-  os << " run();\n ";
-  os << " if(bench_ > 0)\n ";
-  os << "   i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
-}
-
-void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
-                                    const std::string &opname,
-                                    const std::vector<ir::argument*>& args,
-                                    const alloc_map_t& allocs){
-
-  os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
-  for(size_t i = 0; i < args.size(); i++){
-    ir::argument *arg = args[i];
-    std::string name = tf_normalize(arg->get_name());
-    if(!arg->get_type()->is_pointer_ty())
-      os << ".HostMemory(\"" + name + "\")";
-  }
-  for(const auto& x: allocs)
-    os << ".HostMemory(\"" << x.tf_name << "_shape\")";
-  os << ", " + opname << ");\n";
-}
-
-void gen_tf_register_op(std::ostream &os, const std::string &name,
-                        const std::vector<ir::argument*>& args,
-                        const alloc_map_t& allocs){
-
-
-  os << "REGISTER_OP(\"" << name << "\")\n";
-  for(size_t i = 0; i < args.size(); i++)
-    os << "  .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
-  for(size_t i = 0; i < args.size(); i++){
-    ir::argument *arg = args[i];
-    std::string name = tf_normalize(arg->get_name());
-    if(std::find_if(allocs.begin(), allocs.end(),
-                    [&](tf_alloc_t x) {
-                      return name == x.tf_name;
-                    }) == allocs.end())
-      os << "  .Input(\"" << name << ": T" << i << "\")\n";
-    else
-      os << "  .Input(\"" << name << "_shape: int32\")\n";
-  }
-  for(const auto& x: allocs)
-    if(x.type == tf_alloc_t::OUTPUT)
-      os << "  .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n";
-  os << "  .Attr(\"id: int\")\n";
-  os << "  .Attr(\"bench: int\")\n";
-  os << "  .Attr(\"bench_id: int\")\n";
-  os << "  .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
-  size_t current = 0;
-  for(const auto& x: allocs)
-    if(x.type == tf_alloc_t::OUTPUT){
-      os << "    shape_inference::ShapeHandle " << x.tf_name << "_handle;\n";
-      os << "    ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n";
-      os << "    ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n";
-    }
-  os << "  return Status::OK();\n";
-  os << "  })\n";
-
-  os << ";\n";
-}
-
 void make_module(const std::string& src, ir::module* ir,
                  const runtime::function::options_space_t& opt) {
   std::string copy = triton::runtime::function::preheader() + src;
@@ -289,318 +93,6 @@ void make_module(const std::string& src, ir::module* ir,
   gen.Gen(ir);
 }

-std::tuple<std::string,
-           std::string> make_tensorflow_src(const std::string& src,
-                                            const std::vector<std::string>& outputs,
-                                            const std::vector<std::string>& tmp,
-                                            const runtime::function::options_space_t& opt)
-{
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  const std::vector<ir::argument*>& args = fn->args();
-  std::string name = fn->get_name();
-  std::string cc_name = name;
-  cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
-  std::string opname = cc_name + "Op";
-
-  // allocation info
-  alloc_map_t allocs;
-  for(size_t i = 0; i < outputs.size(); i++)
-    allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT));
-  for(size_t i = 0; i < tmp.size(); i++)
-    allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP));
-
-  for(auto &x: allocs){
-    size_t idx;
-    for(idx = 0; idx < args.size(); idx++)
-      if(args[idx]->get_name() == x.name)
-        break;
-    if(idx == args.size())
-      throw std::runtime_error("unknown output");
-    x.shape_id = idx;
-  }
-
-  std::ostringstream oss;
-  oss << R"(
-#include "triton/driver/buffer.h"
-#include "triton/driver/backend.h"
-#include "triton/driver/stream.h"
-#include "triton/runtime/function.h"
-#include "triton/tools/bench.hpp"
-
-#define EIGEN_USE_GPU
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/framework/op_kernel.h"
-
-using namespace tensorflow;
-using GPUDevice = Eigen::GpuDevice;
-namespace rt = triton::runtime;
-namespace drv = triton::driver;
-
-extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
-extern std::map<size_t, int64_t> i64scalar_map;
-
-class )" << opname << R"(: public OpKernel {
- public:
-  explicit )" << opname << R"((OpKernelConstruction* context)
-    : OpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
-    OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
-    OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
-)";
-  for(const auto& alloc: allocs)
-    oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n ";
-
-  oss << R"(
-  }
-
-  void Compute(OpKernelContext* context){
-
-    // get device/stream
-    GPUDevice device = context->eigen_device<GPUDevice>();
-    drv::cu_stream sstream(device.stream(), false);
-    drv::context* ctx = sstream.context();
-    drv::stream* stream = &sstream;
-
-    // extract inputs
-)";
-  gen_extract_inputs(oss, args, allocs);
-  oss << R"(
-    // set outputs
-)";
-  gen_set_outputs(oss, args, allocs);
-  oss << R"(
-    // wrap tensors
-)";
-  gen_make_handles(oss, args);
-  oss << R"(
-)";
-  oss << R"(
-    // launch function
-)";
-  gen_make_launch_function(oss, args);
-  oss << R"(
-  }
-
-private:
-  int id_;
-  int bench_;
-  int64 bench_id_;
-)";
-  for(const auto& alloc: allocs)
-    oss << "DataType " << alloc.name << "_type;\n ";
-
-  oss << R"(
-};
-
-// register kernel builder
-)";
-  gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs);
-  oss << R"(
-// register op
-)";
-  gen_tf_register_op(oss, cc_name, args, allocs);
-
-  return std::tuple<std::string, std::string>{oss.str(), name};
-}
-
-
-inline std::string to_torch_ty(ir::type *ty) {
-  if(ty->is_integer_ty())
-    return "int64_t";
-  if(ty->is_half_ty())
-    return "double";
-  if(ty->is_float_ty())
-    return "double";
-  if(ty->is_double_ty())
-    return "double";
-  if(ty->is_pointer_ty())
-    return "torch::Tensor";
-  throw std::runtime_error("unknown type");
-}
-
-inline std::string to_torch_ty(rt::arg_type ty){
-  switch(ty){
-    case rt::INT1_T:   return "int64_t";
-    case rt::INT8_T:   return "int64_t";
-    case rt::INT16_T:  return "int64_t";
-    case rt::INT32_T:  return "int64_t";
-    case rt::INT64_T:  return "int64_t";
-    case rt::HALF_T:   return "double";
-    case rt::FLOAT_T:  return "double";
-    case rt::DOUBLE_T: return "double";
-    case rt::BUFFER_T: return "torch::Tensor";
-    default:           return "UNKNOWN";
-  }
-}
-
-inline std::string to_c_ty(rt::arg_type ty){
-  switch(ty){
-    case rt::INT1_T:   return "bool";
-    case rt::INT8_T:   return "int8_t";
-    case rt::INT16_T:  return "int16_t";
-    case rt::INT32_T:  return "int32_t";
-    case rt::INT64_T:  return "int64_t";
-    case rt::HALF_T:   return "half";
-    case rt::FLOAT_T:  return "float";
-    case rt::DOUBLE_T: return "double";
-    case rt::BUFFER_T: return "drv::cu_buffer";
-    default:           return "UNKNOWN";
-  }
-}
-
-
-inline std::string to_c_ty(ir::type *ty) {
-  if(ty->is_integer_ty(1))
-    return "bool";
-  if(ty->is_integer_ty(8))
-    return "int8_t";
-  if(ty->is_integer_ty(16))
-    return "int16_t";
-  if(ty->is_integer_ty(32))
-    return "int32_t";
-  if(ty->is_integer_ty(64))
-    return "int64_t";
-  if(ty->is_half_ty())
-    return "half";
-  if(ty->is_float_ty())
-    return "float";
-  if(ty->is_double_ty())
-    return "double";
-  if(ty->is_pointer_ty())
-    return "drv::cu_buffer";
-  throw std::runtime_error("unknown type");
-}
-
-
-
-void gen_torch_signature(std::ostringstream& oss,
-                         const std::string& name,
-                         const std::vector<rt::arg_type>& args) {
-  std::string ret_ty = "void";
-  oss << ret_ty << " " << name << "(";
-  oss << "int64_t id, ";
-  oss << "int64_t bench, ";
-  oss << "int64_t bench_id, ";
-  for(size_t i = 0; i < args.size(); i++) {
-    if(i > 0)
-      oss << ", ";
-    oss << to_torch_ty(args[i]) << " " << "th_arg_" << i;
-  }
-  oss << ")";
-}
-
-void gen_torch_init_driver(std::ostringstream &oss,
-                           const std::vector<rt::arg_type>&args) {
-  // Find index of first buffer
-  size_t i;
-  for(i = 0; i < args.size(); i++)
-    if(args[i] == rt::BUFFER_T)
-      break;
-  oss << " // Wrap CUDA handles" << std::endl;
-  oss << " c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl;
-  oss << " // Get stream" << std::endl;
-  oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
-  oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
-  oss << " triton::driver::context* ctx = stream.context();" << std::endl;
-}
-
-void gen_torch_make_handles(std::ostream &os,
-                            const std::vector<rt::arg_type>& args) {
-  for(unsigned i = 0; i < args.size(); i++){
-    rt::arg_type arg = args[i];
-    const std::string th_name = "th_arg_" + std::to_string(i);
-    const std::string name = "arg_" + std::to_string(i);
-    if(arg != rt::BUFFER_T)
-      os << " " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl;
-    else{
-      os << " CHECK_INPUT(" << th_name << ");" << std::endl;
-      os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), "
-            " (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl;
-    }
-  }
-}
-
-void gen_torch_make_launch_function(std::ostream &os,
-                                    const std::vector<rt::arg_type>& args) {
-  os << " std::function<void()> run = [&](){\n ";
-  os << "   (*id_fn_map.at(id))({";
-  for(unsigned i = 0; i < args.size() ; i++){
-    std::string name = "arg_" + std::to_string(i);
-    if(args[i] == rt::BUFFER_T)
-      name = "&" + name;
-    if(i > 0)
-      os << ", ";
-    os << name;
-  }
-  os << "}, *id_grid_map.at(id), &stream);\n";
-  os << " };\n";
-  os << " run();\n";
-  os << " if(bench > 0)\n ";
-  os << "   i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
-}
-
-void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
-  if(outputs.size() == 1){
-    os << " return " << outputs[0] << ";" << std::endl;
-    return;
-  }
-  os << " return {";
-  for(size_t i = 0; i < outputs.size(); i++){
-    if(i > 0)
-      os << ", ";
-    os << outputs[i];
-  }
-  os << "};" << std::endl;
-}
-
-std::tuple<std::string,
-           std::string> make_torch_src(const std::string& name, std::vector<rt::arg_type> args) {
-  // generate framework code
-  std::ostringstream oss;
-  oss << R"(
-#include "triton/driver/buffer.h"
-#include "triton/driver/stream.h"
-#include "triton/runtime/function.h"
-#include "triton/tools/bench.hpp"
-#include "torch/script.h"
-#include "ATen/cuda/CUDAContext.h"
-
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x);
-
-namespace rt = triton::runtime;
-namespace drv = triton::driver;
-
-extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
-extern std::map<size_t, int64_t> i64scalar_map;
-
-)";
-
-  gen_torch_signature(oss, name, args);
-  oss << " {" << std::endl;
-  gen_torch_init_driver(oss, args);
-  gen_torch_make_handles(oss, args);
-  gen_torch_make_launch_function(oss, args);
-  //gen_torch_ret(oss);
-  oss << "}" << std::endl;
-
-  oss << std::endl;
-  oss << std::endl;
-  oss << "static auto registry = torch::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
-
-  return std::tuple<std::string, std::string>{oss.str(), name};
-}
-
 /* Function signature */
 std::vector<rt::arg_type> get_fn_signature(const std::string& src,
                                            const runtime::function::options_space_t& opt) {
@@ -624,13 +116,6 @@ typedef triton::runtime::function::options_space_t options_space_t;
 PYBIND11_MODULE(libtriton, m) {
     m.doc() = "Python bindings to the C++ Triton API";

-    // framework binding source code generation
-    m.def("make_tensorflow_src", &make_tensorflow_src,
-          "Creates C++ source code for a custom Tensorflow op "
-          "corresponding to the specified Triton kernel");
-    m.def("make_torch_src", &make_torch_src,
-          "Creates C++ source code for a custom PyTorch op ");
-
     // bindings for triton classes
     pybind11::enum_<rt::arg_type>(m, "arg_type")
       .value("int1", rt::INT1_T)
python/src/launch.cc (new file, 30 lines)
@@ -0,0 +1,30 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python

#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x);

namespace rt = triton::runtime;
namespace drv = triton::driver;

typedef std::pair<size_t, size_t> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;

void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
  triton::driver::cu_stream stream(custream, false);
  triton::driver::context* ctx = stream.context();
  (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
}


static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel);
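The comment at the top of launch.cc describes the calling convention: the Python side packs every kernel argument into a single byte string, and `launch_kernel` reinterprets that string as the raw CUDA argument buffer. A minimal sketch of the packing side, assuming the `codes` format characters added to python/triton/kernel.py further down (the op id, device id, and argument order here are illustrative only):

import struct
import torch

# Format characters as in kernel.py's `codes` mapping:
# int32 -> 'I', float -> 'f', device buffer -> 'P' (a raw pointer).
x = torch.empty(1024, device='cuda')
alpha = 1.0
N = x.numel()

# One contiguous byte string holding (pointer, float, int32), laid out the
# way the CUDA driver expects for CU_LAUNCH_PARAM_BUFFER_POINTER.
args = struct.pack('PfI', x.data_ptr(), alpha, N)

# Hypothetical call: op id 0 on device 0 must first have been registered
# through register_fn / register_grid.
# torch.ops.triton.launch_kernel(0, 0, args)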
@@ -1,7 +1,6 @@
 from .kernel import *
-from .utils import *
-import triton.ops
-import triton.nn
+#import triton.ops
+#import triton.nn


 # clean-up libtriton resources
@@ -1,28 +0,0 @@
import sys
import os
import triton._C.libtriton as libtriton

torch = None
tensorflow = None

def _import_torch():
    global torch
    if torch is None:
        import torch

def _import_tensorflow():
    global tensorflow
    if tensorflow is None:
        import tensorflow

def has_tensorflow():
    result = 'tensorflow' in sys.modules
    if result:
        _import_tensorflow()
    return result

def has_torch():
    result = 'torch' in sys.modules
    if result:
        _import_torch()
    return result
@@ -1,199 +1,77 @@
# import for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
import weakref
import contextlib
import io
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton
import os
import time
from struct import pack
import torch

@contextlib.contextmanager
def quiet():
    old_stdout, old_stderr = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
    try:
        yield
    finally:
        sys.stdout, sys.stderr = old_stdout, old_stderr

codes = {
  libtriton.arg_type.int1:   'B',
  libtriton.arg_type.int8:   'B',
  libtriton.arg_type.int32:  'I',
  libtriton.arg_type.int64:  'Q',
  libtriton.arg_type.half:   'H',
  libtriton.arg_type.float:  'f',
  libtriton.arg_type.double: 'd',
  libtriton.arg_type.buffer: 'P'
}

def _build(src, path, name):
    ccdir = os.path.join(libtriton.__file__, os.path.pardir)
    ccdir = os.path.realpath(ccdir)
    # include directories
    triton_include_dirs = [os.path.join(ccdir, 'include')]
    include_dirs = triton_include_dirs
    # library directories
    triton_library_dirs = [ccdir]
    library_dirs = triton_library_dirs
    # libraries
    libraries = ['triton']
    # add framework
    extra_compile_args = []
    if fw.has_torch():
        prefix = os.path.dirname(fw.torch.__file__)
        library_dirs += [os.path.join(prefix, 'lib')]
        include_dirs += ['/usr/local/cuda/include/',
                         os.path.join(prefix, 'lib', 'include'),
                         os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
                         os.path.join(prefix, 'include'),
                         os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
        libraries += ['torch']
        abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI
        extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
    else:
        assert False
    # extra arguments
    extra_link_args = []
    # dependences
    depends = [os.path.realpath(libtriton.__file__)]
    # create extension module
    ext = setuptools.Extension(
        name = name,
        language = 'c++',
        sources = [src],
        include_dirs = include_dirs,
        extra_compile_args = extra_compile_args + ['-g0'],
        extra_link_args = extra_link_args,
        library_dirs = library_dirs,
        libraries = libraries,
        depends = depends
    )
    # build extension module
    args = ['build_ext']
    tmp = tempfile.mkdtemp()
    args.append('--build-temp=' + tmp)
    args.append('--build-lib=' + path)
    args.append('-q')
    args = dict(
        name = name,
        ext_modules = [ext],
        script_args = args,
    )
    with quiet():
        setuptools.setup(**args)
    shutil.rmtree(tmp)

def _cvt_to_def_str(obj):
    # bool
    if isinstance(obj, bool):
        return str(int(obj))
    # torch type
    if fw.has_torch():
        if isinstance(obj, fw.torch.dtype):
            return {fw.torch.int8: 'char',
                    fw.torch.int16: 'short',
                    fw.torch.int32: 'int',
                    fw.torch.int64: 'long',
                    fw.torch.float16: 'half',
                    fw.torch.float32: 'float',
                    fw.torch.float64: 'double'}[obj]
    else:
        assert False
    # default
    return str(obj)

sizes = {
  libtriton.arg_type.int1:   1,
  libtriton.arg_type.int8:   1,
  libtriton.arg_type.int32:  4,
  libtriton.arg_type.int64:  8,
  libtriton.arg_type.half:   2,
  libtriton.arg_type.float:  4,
  libtriton.arg_type.double: 8,
  libtriton.arg_type.buffer: 8
}

def _encode(arg_types):
    codes = {
      libtriton.arg_type.int1:   'i1',
      libtriton.arg_type.int8:   'i8',
      libtriton.arg_type.int32:  'i32',
      libtriton.arg_type.int64:  'i64',
      libtriton.arg_type.half:   'f16',
      libtriton.arg_type.float:  'f32',
      libtriton.arg_type.double: 'f64',
      libtriton.arg_type.buffer: 'buf'
    }
    ret = '_'.join(map(codes.get, arg_types))
    return ret

def th_to_triton(obj):
    tys = {
      torch.int8:    'char',
      torch.int16:   'short',
      torch.int32:   'int',
      torch.int64:   'long',
      torch.float16: 'half',
      torch.float32: 'float',
      torch.float64: 'double'
    }
    if isinstance(obj, torch.dtype):
        return [tys[obj]]
    if isinstance(obj, list):
        return [th_to_triton(x)[0] for x in obj]
    return [str(obj)]
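For reference, `th_to_triton` normalizes each entry of the `defines` dictionary into a list of Triton-C token strings; a quick illustration with made-up values:

```python
defines = {'TYPE': torch.float16, 'TM': [64, 128], 'TZ': 1}   # hypothetical
opt_defines = [(k, th_to_triton(v)) for k, v in defines.items()]
# [('TYPE', ['half']), ('TM', ['64', '128']), ('TZ', ['1'])]
```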
def _make_framework_op(arg_types):
    name = _encode(arg_types)
    # path of .cpp and .so file
    home = os.path.expanduser('~')
    root = os.path.join(home, '.triton', 'torch', name)
    try:
        os.makedirs(root)
    except FileExistsError:
        pass
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(root, f'op{suffix}')
    cpp = os.path.join(root, f'op.cpp')
    # handle cached .so file
    if os.path.exists(so) and os.stat(so).st_size > 0:
        tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime
        so_mtime = os.stat(so).st_mtime
        # can use cached if libtriton is older than the .so
        if tt_mtime < so_mtime:
            fw.torch.ops.load_library(so)
            return getattr(fw.torch.ops.triton, name)
    # create torch source code
    lock = os.path.join(root, f'lock')
    try:
        fd = os.open(lock, os.O_CREAT|os.O_EXCL|os.O_RDWR)
        if os.path.exists(so):
            fw.torch.ops.load_library(so)
            os.remove(lock)
            return getattr(fw.torch.ops.triton, name)
        src, _ = libtriton.make_torch_src(name, arg_types)
        with open(cpp, 'w+') as handle:
            handle.writelines(src)
        # create torch.so
        _build(cpp, root, 'op')
        fw.torch.ops.load_library(so)
        os.remove(lock)
        return getattr(fw.torch.ops.triton, name)
    except FileExistsError:
        # spin until .so is fully written
        while os.path.exists(lock):
            time.sleep(0.01)
        fw.torch.ops.load_library(so)
        return getattr(fw.torch.ops.triton, name)
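The lock file above doubles as a cross-process mutex: `os.open` with `O_CREAT|O_EXCL` succeeds for exactly one process, and every other process spins until the winner removes the file. The same pattern in isolation, with a hypothetical path:

```python
import os, time

lock = '/tmp/build-demo.lock'        # hypothetical lock path
try:
    fd = os.open(lock, os.O_CREAT | os.O_EXCL | os.O_RDWR)
    try:
        pass                         # ... compile and load the artifact exactly once ...
    finally:
        os.close(fd)
        os.remove(lock)              # deleting the file releases the waiters
except FileExistsError:
    while os.path.exists(lock):      # another process is building; spin
        time.sleep(0.01)
```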
def cdiv(a, b):
    return (a + b - 1) // b

class kernel:

    def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
        self.src = src
        # create constants
        self.cst = dict()
        # create triton op
        macros = []
        for k, v in defines.items():
            cvt = lambda x: _cvt_to_def_str(x)
            if(isinstance(v, list)):
                values = list(map(cvt, v))
            else:
                values = [cvt(v)]
            macros.append((k, values))
        opt = libtriton.options_space()
        opt.defines = macros
        opt.num_warps = num_warps
        self.opt = libtriton.options_space()
        self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
        self.opt.num_warps = num_warps
        self.op_id = libtriton.make_op_id()
        libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__))
        # create pytorch hook
        arg_types = libtriton.get_fn_signature(self.src, opt)
        self.fw_op = _make_framework_op(arg_types)
        self.registered = set()
        arg_types = libtriton.get_fn_signature(self.src, self.opt)
        size = sum([sizes[x] for x in arg_types])
        self.tys = ''.join([codes[x] for x in arg_types])

    def set_constant(self, name, value):
        libtriton.register_cst(self.op_id, name, value)
    def set_constant(self, device, name, value):
        libtriton.register_cst((self.op_id, device), name, value)

    def __call__(self, *args, **kwargs):
        for x in args:
            if isinstance(x, torch.Tensor):
                device = x.device.index
                break
        # lazily register function for device
        if device not in self.registered:
            self.registered.add(device)
            libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
        # launch options
        bench = kwargs['bench'] if 'bench' in kwargs else 0
        bench_id = libtriton.make_scalar_id() if bench > 0 else -1
@@ -201,8 +79,8 @@ class kernel:
        if 'grid' not in kwargs:
            raise RuntimeError('Must provide grid for kernel launch')
        grid = kwargs['grid']
        libtriton.register_grid(self.op_id, grid)
        libtriton.register_grid((self.op_id, device), grid)
        # launch
        self.fw_op(self.op_id, bench, bench_id, *args)
        if bench > 0:
            return libtriton.retrieve_scalar(bench_id)
        params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
        torch.cuda.synchronize()
        torch.ops.triton.launch_kernel(self.op_id, device, params)
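Taken together, the new code path is: build the options space, lazily register the function per device, pack the arguments, and launch through the registered torch op. A usage sketch under stated assumptions: the Triton-C source below is made up, and the accessor `opt.d('TN')` for reading a compilation option in the grid callback is an assumption about the Python binding of `options_t::D`:

```python
import torch
import triton

# hypothetical 1D copy kernel written in Triton-C
src = """
__global__ void copy(TYPE *X, TYPE *Y, int N) {
  int pid = get_program_id(0);
  int off[TN] = pid * TN + 0 ... TN;
  bool check[TN] = off < N;
  TYPE* px[TN] = X + off;
  TYPE* py[TN] = Y + off;
  *?(check)py = *?(check)px;
}
"""

N = 4096
copy = triton.kernel(src, defines={'TYPE': torch.float32, 'TN': [128, 256]})
x = torch.rand(N, device='cuda')
y = torch.empty_like(x)
# tensors become raw pointers via pack(); the grid is a function of the
# auto-tuned compilation options
copy(x, y, N, grid=lambda opt: [triton.cdiv(N, opt.d('TN'))])
```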
@@ -1,75 +0,0 @@
import triton.frameworks as fw
import triton._C.libtriton as libtriton
import numpy as np
import weakref

def cdiv(a, b):
    return (a + b - 1) // b

class tf_empty_proxy:

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
        self.tensor = None

    def to_tensor(self):
        assert self.tensor is not None
        return self.tensor

def empty(shape, dtype):
    if fw.has_tensorflow():
        shape = [fw.tensorflow.constant(x) for x in shape]
        shape = fw.tensorflow.stack(shape)
        return tf_empty_proxy(shape, dtype)
        #return fw.tf_extra_ops.alloc_empty(args, T = dtype)
    elif fw.has_torch():
        return fw.torch.empty(shape, dtype=dtype, device='cuda:0')

def shape(A):
    if fw.has_tensorflow():
        return A.shape.as_list()
    elif fw.has_torch():
        return A.shape
    else:
        assert False

class id_dict:

    # Lazy entry for e.g., tensorflow, when value of benchmark is
    # not known at graph compile time
    class lazy_entry:
        def __init__(self, id):
            self.id = id

        def get(self):
            return libtriton.retrieve_scalar(self.id)

    def __init__(self):
        self.data = dict()

    def __delitem__(self, key):
        del self.data[key]

    @staticmethod
    def _get_key(key):
        if fw.has_tensorflow():
            if isinstance(key, fw.tensorflow.Tensor):
                key = id(key.op)
        if fw.has_torch():
            if isinstance(key, fw.torch.Tensor):
                key = id(key)
        return key

    def __getitem__(self, key):
        ret = self.data[id_dict._get_key(key)]
        if isinstance(ret, id_dict.lazy_entry):
            return ret.get()
        return ret

    def __len__(self):
        return len(self.data)

    def __setitem__(self, key, value):
        self.data[id_dict._get_key(key)] = value
@@ -1,4 +1,4 @@
foreach(PROG dot copy)
foreach(PROG dot copy conv)
  set(TARGET bench_${PROG})
  add_executable(${TARGET} ${PROG}.cc)
  set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
39  tests/bench/conv.cc  Normal file
@@ -0,0 +1,39 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "conv.h"

int main() {
  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context);
  // shapes to benchmark
  typedef std::tuple<int, int, int, int, int, int, int, int, int, int, int> config_t;
  std::vector<config_t> configs = {
//    {1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1},
//    {1, 56, 56, 128, 128, 3, 3, 1, 1, 1, 1},
//    {1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1},
//    {1, 56, 56, 384, 384, 3, 3, 1, 1, 1, 1},
//    {1, 56, 56, 512, 512, 3, 3, 1, 1, 1, 1},
//    {1, 56, 56, 768, 768, 3, 3, 1, 1, 1, 1},
//    {1, 56, 56, 1024, 1024, 3, 3, 1, 1, 1, 1},

//    {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1},
//    {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1},
//    {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1},
//    {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1},
      {1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1},
//    {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1}
  };
  int Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w;
  for(const auto& c: configs){
    std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c;
    std::cout << "// " << c;
    for(auto perf: bench_conv(stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
      std::cout << ", " << perf << std::flush;
    std::cout << std::endl;
  }
}
@@ -9,11 +9,40 @@ int main() {
  // shapes to benchmark
  typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
  std::vector<config_t> configs;
  for(auto ord: std::vector<std::vector<int>>{{1, 0}})
  for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
  for(auto ord: std::vector<std::vector<int>>{{0, 1}})
  for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
    std::vector<config_t> tmp = {
//      config_t{ord, x[0], x[1], 128, 128, 128},
//      config_t{ord, x[0], x[1], 256, 256, 256},
//      config_t{ord, x[0], x[1], 384, 384, 384},
//      config_t{ord, x[0], x[1], 512, 512, 512},
//      config_t{ord, x[0], x[1], 768, 768, 768},
        config_t{ord, x[0], x[1], 1024, 1024, 1024},
//      config_t{ord, x[0], x[1], 1280, 1280, 1280},
//      config_t{ord, x[0], x[1], 1536, 1536, 1536},
//      config_t{ord, x[0], x[1], 2048, 2048, 2048},
//      config_t{ord, x[0], x[1], 8192, 8192, 8192},

//      config_t{ord, x[0], x[1], 256, 16, 256},
//      config_t{ord, x[0], x[1], 512, 16, 512},
//      config_t{ord, x[0], x[1], 768, 16, 768},
//      config_t{ord, x[0], x[1], 1024, 16, 1024},
//      config_t{ord, x[0], x[1], 1280, 16, 1280},
//      config_t{ord, x[0], x[1], 1536, 16, 1536},
//      config_t{ord, x[0], x[1], 2048, 16, 2048},
//      config_t{ord, x[0], x[1], 3072, 16, 3072},
//      config_t{ord, x[0], x[1], 4096, 16, 4096},
//      config_t{ord, x[0], x[1], 5120, 16, 5120},
//      config_t{ord, x[0], x[1], 6144, 16, 6144},
//      config_t{ord, x[0], x[1], 7168, 16, 7168},

//      config_t{ord, x[0], x[1], 64, 64, 4096},
//      config_t{ord, x[0], x[1], 64, 64, 8192},
//      config_t{ord, x[0], x[1], 64, 64, 16384},
//      config_t{ord, x[0], x[1], 64, 64, 32768},
//      config_t{ord, x[0], x[1], 64, 64, 65536},
//      config_t{ord, x[0], x[1], 64, 64, 131072}

//      config_t{ord, x[0], x[1], 127008, 768, 576},
//      config_t{ord, x[0], x[1], 8192, 8192, 8192}
//      config_t{ord, x[0], x[1], 16, 2048, 2048},
150  tests/common/conv.h  Normal file
@@ -0,0 +1,150 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/conv.h"
#include "cuda/cublas.h"
#include "util.h"

struct conv_arg_t{
  CUdeviceptr a;
  CUdeviceptr b;
  CUdeviceptr c;
  float alpha;
  int M;
  int N;
  int K;
  int pad_h;
  int pad_w;
  int stride_h;
  int stride_w;
  CUdeviceptr adelta;
  int lda_z;
  int lda_ci;
  int lda_h;
  int lda_w;
  int ldb_ci;
  int ldb_r;
  int ldb_s;
  int ldb_co;
  int ldc_z;
  int ldc_co;
  int ldc_p;
  int ldc_q;
};

enum run_mode_t {
  BENCH,
  TEST
};

enum dtype_t {
  FLOAT,
  HALF,
  DOUBLE
};

template<class T>
struct to_string;

template<> struct to_string<half_float::half>{
  static constexpr const char* value = "half";
};

template<> struct to_string<float>{
  static constexpr const char* value = "float";
};

template<> struct to_string<double>{
  static constexpr const char* value = "double";
};

template<class T>
void triton_conv(drv::stream* stream,
                 int Z, int CI, int H, int W, int CO, int R, int S,
                 int pad_h, int pad_w, int stride_h, int stride_w,
                 run_mode_t mode, std::vector<double>& bench, bool &test){
  std::string ty = to_string<T>::value;
  size_t dt_nbytes = sizeof(T);
  drv::context* context = stream->context();

  int P = (H + 2*pad_h - R)/stride_h + 1;
  int Q = (W + 2*pad_w - S)/stride_w + 1;

  // inputs
  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CO*P*Q*dt_nbytes));
  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CI*H*W*dt_nbytes));
  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*CO*dt_nbytes));
  auto ddelta = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*4));
  auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
  ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());

  std::vector<int32_t> hdelta(CI*R*S);
  int TK = 16;
  for(int i = 0; i < hdelta.size(); i++){
    int s = i % S;
    int cr = i / S;
    int r = cr % R;
    int c = cr / R;
    int nexti = i + TK;
    int nexts = nexti % S;
    int nextcr = nexti / S;
    int nextr = nextcr % R;
    int nextc = nextcr / R;
    hdelta[i] = (nextc - c)*W*H + (nextr - r)*W + (nexts - s);
  }
  stream->write(&*ddelta, true, 0, hdelta);
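The loop above precomputes, for each flattened reduction index k = (ci, r, s), how far the A pointer must jump when k advances by TK; the kernel reads these increments back through ADELTA instead of re-deriving them on the fly. The same table in Python, with made-up shapes:

```python
CI, R, S, H, W, TK = 4, 3, 3, 56, 56, 16   # hypothetical shapes
hdelta = []
for i in range(CI * R * S):
    s, cr = i % S, i // S
    r, c  = cr % R, cr // R
    j = i + TK                             # reduction index one TK-step later
    ns, ncr = j % S, j // S
    nr, nc  = ncr % R, ncr // R
    # pointer jump from element (c, r, s) to (nc, nr, ns) of the input
    hdelta.append((nc - c) * W * H + (nr - r) * W + (ns - s))
```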
  // macros
  rt::function::options_space_t opt;
  opt.defines.push_back({"TYPE", {ty}});
  opt.defines.push_back({"TM", {"128"}});
  opt.defines.push_back({"TN", {"128"}});
  opt.defines.push_back({"TK", {std::to_string(TK)}});
  opt.defines.push_back({"TZ", {"1"}});
  opt.defines.push_back({"RR", {std::to_string(R)}});
  opt.defines.push_back({"SS", {std::to_string(S)}});
  opt.defines.push_back({"PP", {std::to_string(P)}});
  opt.defines.push_back({"QQ", {std::to_string(Q)}});
  opt.defines.push_back({"HH", {std::to_string(H)}});
  opt.defines.push_back({"WW", {std::to_string(W)}});

  opt.num_warps = {2, 4};

  // kernels
  rt::function function(src::conv, opt);
  conv_arg_t args{*da->cu(), *db->cu(), *dc->cu(), 1, Z*P*Q, CO, CI*R*S,
                  pad_h, pad_w, stride_h, stride_w,
                  *ddelta->cu(),
                  W*H*CI, W*H, W, 1,
                  CO*S*R, CO*S, CO, 1,
                  Q*P*CO, Q*P, Q, 1};
  auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
    return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
                      ceil(CO, x.D<int>("TN")),
                      (size_t)x.D<int>("TZ")};
  };
  auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
  double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
  bench.push_back(tflops(triton_ns));
}

std::vector<double> bench_conv(drv::stream* stream, dtype_t dtype,
                               int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S,
                               int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) {
  std::vector<double> bench;
  bool test;
  switch(dtype){
    case HALF:   triton_conv<half_float::half>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
    case FLOAT:  triton_conv<float>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
    case DOUBLE: triton_conv<double>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
    default: break;
  }
  return bench;
}
@@ -12,6 +12,15 @@ int32_t off(const std::vector<int32_t>& idx, const std::vector<int32_t>& strides
  return res;
}

struct copy_arg_t{
  CUdeviceptr X;
  CUdeviceptr Y;
  int S0;
  int S1;
  int S2;
};

enum run_mode_t {
  BENCH,
  TEST
@@ -115,16 +124,16 @@ void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,

  // kernel
  rt::function function(src::copy_nd[rank - 1], opt);
  std::vector<rt::arg> args = {&*dx, &*dy};
  for(int32_t d: shape)
    args.push_back(d);
  copy_arg_t args = {*dx->cu(), *dy->cu(), shape[0]};
  if(shape.size() > 1) args.S1 = shape[1];
  if(shape.size() > 2) args.S2 = shape[2];
  std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
  auto grid = grid_nd(shape, ts);

  // metrics
  if(mode == BENCH){
    auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
    double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
    bench.push_back(gbps(triton_ns));
  }

@@ -136,7 +145,7 @@ void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
  for(size_t i = 0; i < hx.size(); i++)
    hx[i] = static_cast<T>((float)rand()/RAND_MAX);
  stream->write(&*dx, true, 0, hx);
  function(args, grid, stream);
  function((void**)&args, sizeof(args), grid, stream);
  stream->synchronize();
  stream->read(&*dy, true, 0, hy);
  cc_copy_nd(hx, ry, shape, x_order, y_order);
@@ -2,6 +2,7 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include <tuple>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
@@ -12,6 +13,20 @@
#include "util.h"

struct dot_arg_t{
  CUdeviceptr a;
  CUdeviceptr b;
  CUdeviceptr c;
  float alpha;
  int M;
  int N;
  int K;
  int lda;
  int ldb;
  int ldc;
  CUdeviceptr locks;
};

template<class T, bool AT, bool BT>
static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
                   size_t M, size_t N, size_t K){
@@ -64,7 +79,7 @@ template<> struct to_string<double>{
};

template<class T>
bool triton_dot(drv::stream* stream, bool AT, bool BT,
void triton_dot(drv::stream* stream, bool AT, bool BT,
                int32_t M, int32_t N, int32_t K,
                int32_t TM, int32_t TN, int32_t TK, int32_t nwarp,
                const std::vector<int>& a_order, const std::vector<int>& b_order,
@@ -79,9 +94,11 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
  std::vector<std::string> sb = { "1", "ldb" };

  // inputs
  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
  auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
  ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());

  // macros
  rt::function::options_space_t opt;
@@ -106,33 +123,41 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
    opt.defines.push_back({"TM", {std::to_string(TM)}});
    opt.defines.push_back({"TN", {std::to_string(TN)}});
    opt.defines.push_back({"TK", {std::to_string(TK)}});
    opt.defines.push_back({"TZ", {"1"}});
    opt.num_warps = {nwarp};
  }
  if(mode == BENCH) {
    opt.defines.push_back({"TM", {"128"}});
    opt.defines.push_back({"TN", {"32"}});
    opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
    opt.defines.push_back({"TN", {"128"}});
    opt.defines.push_back({"TK", {"16"}});
    opt.defines.push_back({"TZ", {"1"}});
    opt.num_warps = {4};
  }

  // kernels
  rt::function function(src::dot, opt);
  std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
  auto grid = grid2d(M, N);
  dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(),
                    1, M, N, K, lda, ldb, ldc, *dlocks->cu()};

  auto grid = [M, N](const rt::function::options_t& x) {
    return rt::grid_t{ceil(M, x.D<int>("TM")),
                      ceil(N, x.D<int>("TN")),
                      (size_t)x.D<int>("TZ")};
  };

  // metrics
  if(mode == BENCH){
    auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
    double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
    bench.push_back(tflops(triton_ns));

//    // cublas
//    cublas
//    if(cublas::cublasinit()){
//      T alpha(static_cast<double>(1));
//      T beta(static_cast<double>(0));
//      cublasGemmAlgo_t fastest;
//      cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
//                                                                 &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
//                                                                 ldc, nullptr, fastest); }, stream);
//      bench.push_back(tflops(cublas_ms));
@@ -154,7 +179,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
    stream->write(&*da, true, 0, ha);
    stream->write(&*db, true, 0, hb);
    // run kernel
    function(args, grid, stream);
    function((void**)&args, sizeof(args), grid, stream);
    // write back
    stream->synchronize();
    // compare with CPU
@@ -13,6 +13,15 @@
namespace drv = triton::driver;
namespace rt = triton::runtime;

struct reduce_arg_t{
  CUdeviceptr X;
  CUdeviceptr Y;
  int S0;
  int S1;
  int S2;
};

template<class T>
void cc_reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
  assert(axis <= shapes.size() - 1);
@@ -123,16 +132,16 @@ void triton_reduce_nd(drv::stream* stream, const std::vector<int32_t>& shape_x,
  auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));

  // grid
  std::vector<rt::arg> args = {&*dx, &*dy};
  for(int32_t d: shape_x)
    args.push_back(d);
  reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]};
  if(shape_x.size() > 1) args.S1 = shape_x[1];
  if(shape_x.size() > 2) args.S2 = shape_x[2];
  std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
  auto grid = grid_nd(shape_x, ts);

  // metrics
  if(mode == BENCH){
    auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
    double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
    bench.push_back(gbps(triton_ns));
  }

@@ -144,7 +153,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector<int32_t>& shape_x,
  init_zeros(hy);
  init_rand(hx);
  stream->write(&*dx, true, 0, hx);
  function(args, grid, stream);
  function((void**)&args, sizeof(args), grid, stream);
  stream->synchronize();
  stream->read(&*dy, true, 0, hy);
  cc_reduce_nd(ry, hx, op, axis, shape_x);
134  tests/common/src/conv.h  Normal file
@@ -0,0 +1,134 @@
namespace src {

const char *conv =
R"(
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
                     TYPE *B __noalias __readonly __aligned(16),
                     TYPE *C __noalias __aligned(16),
                     float alpha,
                     // equivalent matmul
                     int M __retune,
                     int N __retune,
                     int K __retune,
                     // convolution properties
                     int pad_h, int pad_w, int stride_h, int stride_w,
                     // pointer increment
                     int *ADELTA,
                     // memory strides
                     int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
                     int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
                     int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
  // prologue
  int ridx = get_program_id(0);
  int ridy = get_program_id(1);
  int ridz = get_program_id(2);
  int rm[TM] = ridx * TM + 0 ... TM;
  int rn[TN] = ridy * TN + 0 ... TN;
  // reduction splitting
  K = K / TZ;
  int rk[TK] = ridz * K + 0 ... TK;

  // unpack aggregate rows
  // m = (z, p, q)
  int rq[TM] = rm % QQ;
  int rzp[TM] = rm / QQ;
  int rp[TM] = rzp % PP;
  int rz[TM] = rzp / PP;
  // unpack aggregate reduction
  // k = (ci, r, s)
  int rs [TK] = rk % SS;
  int rcir[TK] = rk / SS;
  int rr [TK] = rcir % RR;
  int rci [TK] = rcir / RR;

  // padding / striding
  int rh_0[TM] = rp * stride_h - pad_h;
  int rw_0[TM] = rq * stride_w - pad_w;
  int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
  int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];

  // pointers to lhs
  int offa[TM, TK] = rz [:, newaxis] * lda_z +
                     rci[newaxis, :] * lda_ci +
                     rh * lda_h +
                     rw * 1;
  TYPE* pa[TM, TK] = A + offa;
  // pointers to rhs
  int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
                     rr [:, newaxis] * ldb_r +
                     rs [:, newaxis] * ldb_s +
                     rn [newaxis, :] * 1;
  TYPE* pb[TK, TN] = B + offb;
  // pointers to delta
  int* padelta[TK] = ADELTA + rk;
  int adelta[TK] = *padelta;

  // prefetches operands
  bool checkam[TM, TK] = rm[:, newaxis] < M;
  bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
  bool checkbn[TK, TN] = rn[newaxis, :] < N;
  bool checkbk[TK, TN] = rk[:, newaxis] < K;
  bool checkb[TK, TN] = checkbn && checkbk;
  TYPE a[TM, TK] = checka ? *pa : 0;
  TYPE b[TK, TN] = checkb ? *pb : 0;

  // reduction loop
  float acc[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
    acc += a @ b;
    // increment A
    pa += adelta[newaxis, :];
    // bounds-checking A
    rk += TK;
    rs = rk % SS;
    rcir = rk / SS;
    rr = rcir % RR;
    rh = rh_0[:, newaxis] + rr[newaxis, :];
    rw = rw_0[:, newaxis] + rs[newaxis, :];
    bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
    // increment B
    pb += TK * ldb_s;
    // bounds-checking B
    bool checkbk = k > TK;
    bool checkb[TK, TN] = checkbn && checkbk;
    a = checka ? *pa : 0;
    b = *?(checkb)pb;
    padelta += TK;
    adelta = *padelta;
  }
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;

  // epilogue
  rm = ridx * TM + 0 ... TM;
  rn = ridy * TN + 0 ... TN;
  rq = rm % QQ;
  rzp = rm / QQ;
  rp = rzp % PP;
  rz = rzp / PP;
  int offc[TM, TN] = rz [:, newaxis] * ldc_z +
                     rn [newaxis, :] * ldc_co +
                     rp [:, newaxis] * ldc_p +
                     rm [:, newaxis] * 1;
  TYPE* pc[TM, TN] = C + offc;
  bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;

#if (TZ==1)
  *?(checkc) pc = c;
#else
  // accumulate partial result using spin-locks
  int *plock = locks + rid;
  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
  for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  int count = *pcount;
  if(count == 0)
    *?(checkc) pc = c;
  else
    *?(checkc) pc = c + *?(checkc)pc;
  atomic_xchg(pcount, (count + 1) % TZ);
  atomic_xchg(plock, 0);
#endif
}
)";

}
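The `% QQ` / `/ QQ` chains in the prologue and epilogue above are mixed-radix decoding of the flattened row index rm into (z, p, q), and likewise of rk into (ci, r, s). A quick sanity check of the row decoding in Python, with made-up extents:

```python
Z, P, Q = 2, 7, 5                 # hypothetical output extents
for rm in range(Z * P * Q):
    rq  = rm %  Q
    rzp = rm // Q
    rp  = rzp %  P
    rz  = rzp // P
    assert rm == (rz * P + rp) * Q + rq   # decoding round-trips exactly
```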
@@ -7,7 +7,7 @@ namespace src {
R"(
void copy1d(TYPE * X __noalias __readonly __aligned(16),
            TYPE * Y __noalias __readonly __aligned(16),
            int S0) {
            int S0 __retune) {
  int pid0 = get_program_id(0);
  int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
  TYPE* px[TS0] = X + rs0;
@@ -20,8 +20,8 @@ void copy1d(TYPE * X __noalias __readonly __aligned(16),
R"(
void copy2d(TYPE * X __noalias __readonly __aligned(16),
            TYPE * Y __noalias __writeonly __aligned(16),
            int S0 __multipleof(8),
            int S1 __multipleof(8)) {
            int S0 __multipleof(8) __retune,
            int S1 __multipleof(8) __retune) {
  int pid0 = get_program_id(0);
  int pid1 = get_program_id(1);
  int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
@@ -37,9 +37,9 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
R"(
void copy3d(TYPE * X __noalias __readonly __aligned(16),
            TYPE * Y __noalias __writeonly __aligned(16),
            int S0 __multipleof(8),
            int S1 __multipleof(8),
            int S2 __multipleof(8)) {
            int S0 __multipleof(8) __retune,
            int S1 __multipleof(8) __retune,
            int S2 __multipleof(8) __retune) {
  // program id
  int pid0 = get_program_id(0);
  int pid1 = get_program_id(1);
@@ -6,13 +6,17 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                    TYPE * B __noalias __readonly __aligned(16),
                    TYPE * C __noalias __aligned(16),
                    float alpha,
                    int M, int N, int K,
                    int M __retune,
                    int N __retune,
                    int K __retune __multipleof(16),
                    int lda __multipleof(8),
                    int ldb __multipleof(8),
                    int ldc __multipleof(8)) {
                    int ldc __multipleof(8),
                    int* locks) {
  // prologue
  int ridx = get_program_id(0);
  int ridy = get_program_id(1);
  int ridz = get_program_id(2);
  int gridx = M / TM;
  int gridy = N / TN;
  int rid = ridx + ridy * gridx;
@@ -20,7 +24,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
  ridy = rid % gridy;
  int rm[TM] = ridx * TM + 0 ... TM;
  int rn[TN] = ridy * TN + 0 ... TN;
  int rk[TK] = 0 ... TK;

  // reduction splitting
  K = K / TZ;
  int rk[TK] = ridz * K + 0 ... TK;

  // pointers to operands
  int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
@@ -35,9 +42,9 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
  TYPE b[SHAPE_B] = checkb ? *pb : 0;

  // reduction loop
  float c[TM, TN] = 0;
  float acc[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
    c += USEA @ USEB;
    acc += USEA @ USEB;
    bool checka[SHAPE_A] = k > TK;
    bool checkb[SHAPE_B] = k > TK;
    pa += TK * STRIDE_AK;
@@ -45,7 +52,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
    a = *?(checka)pa;
    b = *?(checkb)pb;
  }
  //c = c * alpha;
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;

  // epilogue
  int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
@@ -53,7 +61,22 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
  int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
  TYPE* pc[TM, TN] = C + offc;
  bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
  *?(checkc)pc = (TYPE[TM, TN])c;

#if (TZ==1)
  *?(checkc) pc = c;
#else
  // accumulate partial result using spin-locks
  int *plock = locks + rid;
  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
  for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  int count = *pcount;
  if(count == 0)
    *?(checkc) pc = c;
  else
    *?(checkc) pc = c + *?(checkc)pc;
  atomic_xchg(pcount, (count + 1) % TZ);
  atomic_xchg(plock, 0);
#endif
}
)";
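The `#else` branch implements split-K: each of the TZ program instances reduces one slice of K and deposits its partial tile under a spin-lock, with a counter deciding whether to store or accumulate, so the tile is complete once all TZ partials have landed. A rough Python analogue of the protocol (a scalar stands in for the TM x TN tile, and `threading.Lock` stands in for `atomic_cas`/`atomic_xchg`):

```python
import threading

TZ = 4                       # split-K factor: number of partial reducers
plock = threading.Lock()     # plock
count = 0                    # pcount: partials deposited so far
c = 0.0                      # output tile

def deposit(partial):
    global count, c
    with plock:              # for(...; atomic_cas(plock, 0, 1)); ... atomic_xchg(plock, 0)
        c = partial if count == 0 else c + partial
        count = (count + 1) % TZ   # atomic_xchg(pcount, (count + 1) % TZ)

parts = [1.0, 2.0, 3.0, 4.0]       # hypothetical partial sums
threads = [threading.Thread(target=deposit, args=(p,)) for p in parts]
for t in threads: t.start()
for t in threads: t.join()
print(c)                           # 10.0: all TZ partials accumulated
```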
@@ -16,7 +16,7 @@ int main() {
  for(int nwarps: std::vector<int>{4})
  for(bool AT: std::array<bool, 2>{false, true})
  for(bool BT: std::array<bool, 2>{false, true}){
    configs.push_back(config_t{HALF, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
    configs.push_back(config_t{FLOAT, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
  }
  // test
  dtype_t dtype;