From 269ebc12e5f96355e5ebee510cf882ceed73502f Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 29 Jan 2021 17:27:16 -0500 Subject: [PATCH] [PYTHON][TESTS][DOC] Various improvement of the API and code quality: * Simplified `triton.kernel` API to achieve lower latency: > .data_ptr() must now be passed as kernel argument. No more implicit conversion from torch.tensor > compilation options are now constant attributes, i.e., opt.d('VAR') becomes opt.VAR > torch.device must now be passed explicitly to triton.kernel (no longer inferred from torch.tensor arguments) * C++ tests moved to `python/tests/` * C++ tutorial created in `tutorials/` * Python tutorial created in python/tutorials/ * Version changed to 1.0alpha * No longer copying C++ headers into the Python package * added python/triton/ops/ package for pre-written Triton ops --- CMakeLists.txt | 11 +- LICENSE | 2 +- README.md | 19 +- docs/index.rst | 1 - docs/tutorials/custom-operation.rst | 102 --- docs/tutorials/index.rst | 10 - docs/tutorials/matrix-multiplication.rst | 186 ------ docs/tutorials/matrix-transposition.rst | 174 ----- docs/tutorials/putting-it-all-together.rst | 180 ------ docs/tutorials/triton-vs-cuda.rst | 109 ---- include/triton/driver/stream.h | 6 +- include/triton/runtime/function.h | 153 ++--- lib/codegen/analysis/layout.cc | 1 - lib/driver/stream.cc | 4 +- lib/lang/parser.cc | 2 - lib/runtime/function.cc | 610 ++++++------------ python/examples/tutorials/add.py | 39 -- python/examples/tutorials/copy.py | 70 -- python/examples/tutorials/matmul.py | 143 ---- python/examples/tutorials/trans.py | 76 --- python/setup.py | 23 +- python/src/bindings.cc | 149 +++-- python/src/launch.cc | 95 --- python/src/torch/launch.cc | 83 +++ python/src/torch/superblock.cc | 117 ++++ python/tests/test_blocksparse.py | 50 ++ python/tests/test_conv.py | 17 + python/tests/test_matmul.py | 96 +++ python/tests/test_softmax.py | 8 + python/triton/__init__.py | 11 +- python/triton/kernel.py | 117 ++-- python/triton/ops/__init__.py | 4 + python/triton/ops/blocksparse/__init__.py | 1 + python/triton/ops/blocksparse/matmul.c | 198 ++++++ python/triton/ops/blocksparse/matmul.py | 467 ++++++++++++++ .../tutorials/conv2d.py => triton/ops/conv.c} | 82 +-- python/triton/ops/conv.py | 57 ++ .../src/dot.h => python/triton/ops/matmul.c | 60 +- python/triton/ops/matmul.py | 80 +++ python/triton/ops/softmax.c | 8 + python/triton/ops/softmax.py | 27 + python/tutorials/01-vector-add.ipynb | 335 ++++++++++ tests/CMakeLists.txt | 3 - tests/bench/conv.cc | 39 -- tests/bench/copy.cc | 36 -- tests/bench/copy1d.cc | 54 -- tests/bench/dot.cc | 72 --- tests/common/conv.h | 140 ---- tests/common/copy.h | 186 ------ tests/common/cuda/cublas.h | 218 ------- tests/common/cuda/forward.h | 105 --- tests/common/dot.h | 215 ------ tests/common/reduce.h | 173 ----- tests/common/src/conv.h | 134 ---- tests/common/src/copy.h | 67 -- tests/common/src/reduce.h | 58 -- tests/common/util.h | 225 ------- tests/unit/CMakeLists.txt | 6 - tests/unit/copy.cc | 60 -- tests/unit/dot.cc | 70 -- tests/unit/reduce.cc | 45 -- tutorials/01-matmul.cc | 245 +++++++ {tests/bench => tutorials}/CMakeLists.txt | 4 +- 63 files changed, 2255 insertions(+), 3883 deletions(-) delete mode 100644 docs/tutorials/custom-operation.rst delete mode 100644 docs/tutorials/index.rst delete mode 100644 docs/tutorials/matrix-multiplication.rst delete mode 100644 docs/tutorials/matrix-transposition.rst delete mode 100644 docs/tutorials/putting-it-all-together.rst delete mode 100644 docs/tutorials/triton-vs-cuda.rst delete mode 100644 python/examples/tutorials/add.py delete mode 100644 python/examples/tutorials/copy.py delete mode 100644 python/examples/tutorials/matmul.py delete mode 100644 python/examples/tutorials/trans.py delete mode 100644 python/src/launch.cc create mode 100644 python/src/torch/launch.cc create mode 100644 python/src/torch/superblock.cc create mode 100644 python/tests/test_blocksparse.py create mode 100644 python/tests/test_conv.py create mode 100644 python/tests/test_matmul.py create mode 100644 python/tests/test_softmax.py create mode 100644 python/triton/ops/__init__.py create mode 100644 python/triton/ops/blocksparse/__init__.py create mode 100644 python/triton/ops/blocksparse/matmul.c create mode 100644 python/triton/ops/blocksparse/matmul.py rename python/{examples/tutorials/conv2d.py => triton/ops/conv.c} (63%) create mode 100644 python/triton/ops/conv.py rename tests/common/src/dot.h => python/triton/ops/matmul.c (62%) create mode 100644 python/triton/ops/matmul.py create mode 100644 python/triton/ops/softmax.c create mode 100644 python/triton/ops/softmax.py create mode 100644 python/tutorials/01-vector-add.ipynb delete mode 100644 tests/CMakeLists.txt delete mode 100644 tests/bench/conv.cc delete mode 100644 tests/bench/copy.cc delete mode 100644 tests/bench/copy1d.cc delete mode 100644 tests/bench/dot.cc delete mode 100644 tests/common/conv.h delete mode 100644 tests/common/copy.h delete mode 100644 tests/common/cuda/cublas.h delete mode 100644 tests/common/cuda/forward.h delete mode 100644 tests/common/dot.h delete mode 100644 tests/common/reduce.h delete mode 100644 tests/common/src/conv.h delete mode 100644 tests/common/src/copy.h delete mode 100644 tests/common/src/reduce.h delete mode 100644 tests/common/util.h delete mode 100644 tests/unit/CMakeLists.txt delete mode 100644 tests/unit/copy.cc delete mode 100644 tests/unit/dot.cc delete mode 100644 tests/unit/reduce.cc create mode 100644 tutorials/01-matmul.cc rename {tests/bench => tutorials}/CMakeLists.txt (75%) diff --git a/CMakeLists.txt b/CMakeLists.txt index fddc6f1f6..a293f61ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ include(CTest) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # Options -option(BUILD_TESTS "Build C++ Triton tests" ON) +option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) # LLVM @@ -25,16 +25,17 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=defa # Tests if(BUILD_TESTS) - message(STATUS "Adding C++ tests") - add_subdirectory(tests) + message(STATUS "Adding C++ tutorials") + add_subdirectory(tutorials) endif() # Python module if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") # PyBind11 wrapper source file - set(PYTHON_SRC bindings.cc launch.cc) - set_source_files_properties(launch.cc PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}") + set(TORCH_SRC torch/launch.cc torch/superblock.cc) + set(PYTHON_SRC bindings.cc ${TORCH_SRC}) + set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}") include_directories("." ${PYTHON_INCLUDE_DIRS}) link_directories(${PYTHON_LINK_DIRS}) endif() diff --git a/LICENSE b/LICENSE index 24d8ee2c6..b08bf852a 100755 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -/* Copyright 2018-2020 Philippe Tillet +/* Copyright 2018-2021 Philippe Tillet * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files diff --git a/README.md b/README.md index f03181354..f2b27b8d3 100644 --- a/README.md +++ b/README.md @@ -2,20 +2,11 @@ This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs. -The main components of Triton at the moment are: -- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level (think numpy-like array operations in a C-like language). -- **Triton-IR**: A special-purpose intermediate representation (Triton-IR) for aiding array-level program analysis and optimizations in Triton-C programs. -- **Triton-JIT**: An optimizing just-in-time compiler for Triton-IR, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)). This includes transparent support for mixed-precision and Tensor Cores. - -Bindings for **automatic** PyTorch custom op generations are included in **PyTriton**, along with a small DSL based on einsum that supports convolutions, shift-convolutions, direct einsums, etc. - -The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work! +The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work! ## Installation -Triton is a fairly self-contained package and uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and LLVM-10+ for code generation. - You can install the latest release with pip as follows: ``` sudo apt-get install llvm-10-dev @@ -30,6 +21,7 @@ or the latest development version with: for the C++ package: ``` git clone https://github.com/ptillet/triton.git; +cd triton; mkdir build; cd build; cmake ../; @@ -39,9 +31,4 @@ make -j8; ## Getting Started -Please visit the [documentation](https://docs.triton-lang.org) to get started with Triton - - -## Contributing - -Please keep in mind that this is a project I have been carrying out completely on my own as part of my Ph.D. thesis. While I am confident in the approach, there are still many things to fix and to polish. Please contact me (ptillet AT g.harvard.edu) or raise an issue if you want to contribute! \ No newline at end of file +You can find tutorials for Triton for [Python](https://github.com/ptillet/triton/tree/master/tutorials) and [C++](https://github.com/ptillet/triton/tree/master/python/tutorials). \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index cad9f3bb0..d82205b4c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,7 +11,6 @@ Welcome to Triton's documentation! :caption: Contents: installation/index - tutorials/index Indices and tables diff --git a/docs/tutorials/custom-operation.rst b/docs/tutorials/custom-operation.rst deleted file mode 100644 index 28c724ede..000000000 --- a/docs/tutorials/custom-operation.rst +++ /dev/null @@ -1,102 +0,0 @@ -=========================== -Writing a Custom Operation -=========================== - --------------- -Compute Kernel --------------- - -Let us start with something simple, and see how Triton can be used to create a custom vector addition for PyTorch. The Triton compute kernel for this operation is the following: - -.. code-block:: C - - // Triton - // launch on a grid of (N + TILE - 1) / TILE programs - __global__ void add(float* z, float* x, float* y, int N){ - // program id - int pid = get_program_id(0); - // create arrays of pointers - int offset[TILE] = pid * TILE + 0 ... TILE; - float* pz[TILE] = z + offset; - float* px[TILE] = x + offset; - float* py[TILE] = y + offset; - // bounds checking - bool check[TILE] = offset < N; - // write-back - *?(check)pz = *?(check)px + *?(check)py; - } - -As you can see, arrays are first-class citizen in Triton. This has a number of important advantages that will be highlighted in the next tutorial. For now, let's keep it simple and see how to execute the above operation in PyTorch. - ---------------- -PyTorch Wrapper ---------------- - -As you will see, a wrapper for the above Triton function can be created in just a few lines of pure python code. - -.. code-block:: python - - import torch - import triton - - class _add(torch.autograd.Function): - # source-code for Triton compute kernel - src = """ - __global__ void add(float* z, float* x, float* y, int N){ - // program id - int pid = get_program_id(0); - // create arrays of pointers - int offset[TILE] = pid * TILE + 0 ... TILE; - float* pz[TILE] = z + offset; - float* px[TILE] = x + offset; - float* py[TILE] = y + offset; - // bounds checking - bool check[TILE] = offset < N; - // write-back - *?(check)pz = *?(check)px + *?(check)py; - } - """ - # create callable kernel for the source-code - # options: 4 warps and a -DTILE=1024 - kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4]) - - # Forward pass - @staticmethod - def forward(ctx, x, y): - # type checking - assert x.dtype == torch.float32 - # allocate output - z = torch.empty_like(x).cuda() - # create launch grid - # this is a function of the launch parameters - # triton.cdiv indicates ceil division - N = x.numel() - grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), ) - # launch kernel - _add.kernel(z, x, y, N, grid = grid) - # return output - return z - - # get callable from Triton function - add = _add.apply - - # test - torch.manual_seed(0) - x = torch.rand(98432).cuda() - y = torch.rand(98432).cuda() - za = x + y - zb = add(x, y) - diff = (za - zb).abs().max() - print(diff) - print(torch.allclose(za,zb)) - -Executing the above code will: - -- Generate a .cpp file containing PyTorch bindings for the Triton function -- Compile this .cpp file using distutils -- Cache the resulting custom op -- Call the resulting custom op - -In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation. - -A runnable version of this kernel is available `here `_. diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst deleted file mode 100644 index 1cd7548ce..000000000 --- a/docs/tutorials/index.rst +++ /dev/null @@ -1,10 +0,0 @@ -Tutorials -========== - -.. toctree:: - :maxdepth: 1 - - custom-operation - triton-vs-cuda - matrix-transposition - matrix-multiplication diff --git a/docs/tutorials/matrix-multiplication.rst b/docs/tutorials/matrix-multiplication.rst deleted file mode 100644 index fbdd3eabb..000000000 --- a/docs/tutorials/matrix-multiplication.rst +++ /dev/null @@ -1,186 +0,0 @@ -********************* -Matrix Multiplication -********************* - -The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see `CUTLASS `_). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores. - -*Note: Bounds-checking is ommitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops.* - -============== -Compute Kernel -============== - -Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below: - -.. code-block:: C - - // Triton-C - // launched on a grid of (M / TM) x (N / TN) programs - __global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K, - int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) { - // prologue - int pm = get_program_id(0); //(1) - int pn = get_program_id(1); //(2) - int rm[TM] = pm * TM + 0 ... TM; //(3) - int rn[TN] = pn * TN + 0 ... TN; //(4) - int rk[TK] = 0 ... TK; //(5) - // initialize accumulator - float c[TM, TN] = 0; //(6) - // pointers to operands - TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7) - TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8) - // reduction loop - for(int k = K; k > 0; k-= TK){ - // fetch operands - TYPE a[TM, TK] = *pa; //(9) - TYPE b[TK, TN] = *pb; //(10) - // matrix-multiply accumulate - c += a @ b; //(11) - // increment pointers - pa = pa + TK * 1; //(12) - pb = pb + TK * ldb; //(13) - } - // epilogue - TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14) - *pc = c; //(15) - } - -Here, each kernel instance produces a :code:`TM x TN` tile of the output matrix C as follows: - -- Statements (1) - (2) fetch the id of the current program instance. -- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix :code:`C` -- Statement (5) constructs a range of indices along the reduction axis: :code:`rk = [0, 1, ..., TK - 1]` -- Statement (6) initialize a :code:`TM x TN` array of accumulators to hold the result of :code:`A[rm, :] x B[:, rn]` -- Statements (7) - (8) initializes arrays of pointers :code:`pa` and :code:`pb` to the operands :code:`A` and :code:`B` using logic similar to that of the above transposition kernel -- Statements (9) - (10) load tiles of operands by dereferencing :code:`pa` and :code:`pb` -- Statement (11) performs updates the accumulator array using Triton-C's matrix multiplication operator :code:'@' -- Statements (12) - (13) updates :code:`pa` and :code:`pb` -- Statement (14) creates an array of pointers `pc` to the result matrix :code:`C` -- Statement (15) writes back the accumulator to :code:`C` - -Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel: - -- Automatic coalescing of load/store operations -- Automatic vectorization of load/store operations -- Stashing `a` and `b` to shared memory -- Automatic allocation of shared memory -- Automatic synchronization of shared memory -- Automatic padding of shared memory to avoid bank conflicts -- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0 - - -============== -Optimizations -============== - -Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code. - -------------- -Pre-Fetching -------------- - -The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows: - - -.. code-block:: C - - // pre-fetch operands - TYPE a[TM, TK] = *pa; //(9) - TYPE b[TK, TN] = *pb; //(10) - for(int k = K; k > 0; k-= TK){ - c += a @ b; - pa = pa + TK * 1; - pb = pb + TK * ldb; - // don't prefetch last iteration - bool check = k > TK; - // pre-fetch operands - a = check ? *pa : 0; - b = check ? *pb : 0; - } - - -Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard. - ------------------ -Rematerialization ------------------ - -`Rematerialization `_ is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only :code:`rm` and :code:`rn` need to be rematerialized, leading to the following epilogue: - -.. code-block:: C - - // epilogue - int rcm[TM] = pm * TM + 0 ... TM; - int rcn[TN] = pn * TN + 0 ... TN; - TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc; - *pc = c; - - ------------------------------------- -Fused Transpositions and Auto-Tuning ------------------------------------- - -It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel: - -.. code-block:: C - - // Triton-C - // launched on a grid of (M / TM) x (N / TN) programs - void dot(TYPE * A, TYPE * B, TYPE * C, - int M, int N, int K, - int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) { - // prologue - int pm = get_program_id(0); - int pn = get_program_id(1); - int rm[TM] = pm * TM + 0 ... TM; - int rn[TN] = pn * TN + 0 ... TN; - int rk[TK] = 0 ... TK; - float c[TM, TN] = 0; - // pointers to operands - TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; - TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; - // prefetches operands - TYPE a[SHAPE_A] = (*pa); - TYPE b[SHAPE_B] = (*pb); - // reduction loop - for(int k = K; k > 0; k-= TK){ - c += USE_A @ USE_B; - pa = pa + TK * STRIDE_AK; - pb = pb + TK * STRIDE_BK; - a = *pa; - b = *pb; - } - // epilogue - int rcm[TM] = pm * TM + 0 ... TM; - int rcn[TN] = pn * TN + 0 ... TN; - TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc; - *pc = c; - } - - -All matrix multiplications variants can then be retrieved using the following compilation option: - -.. code-block:: C - - // A is not transposed - -DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda - -DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK - // A is transposed - -DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1 - -DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM - // B is not transpose - -DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1 - -DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN - // B is transpose - -DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb - -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK - - -Auto-tuning can also be handled using pre-processor macros: - -.. code-block:: C - - // Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16} - -DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16] - -A runnable version of this kernel is available `here `_. diff --git a/docs/tutorials/matrix-transposition.rst b/docs/tutorials/matrix-transposition.rst deleted file mode 100644 index d51ff6b41..000000000 --- a/docs/tutorials/matrix-transposition.rst +++ /dev/null @@ -1,174 +0,0 @@ -********************* -Matrix Transpositions -********************* - - -Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM. - -Of course, this can be fixed by using shared memory as shown `here `_, but this comes at the cost of simplicity interferes with auto-tuning. - -============== -Compute Kernel -============== - -In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory. Therefore, an optimal Triton kernel for this operation would look like: - -.. code-block:: C - - // launched on a grid of (M / TM) x (N / TN) programs of 1 thread each - __global__ void transpose(TYPE * X, TYPE * Y, - int M, int N, int ldx, int ldy) { - // extract program ID - int pidm = get_program_id(0); //(1) - int pidn = get_program_id(1); //(2) - // create 1D range along the two matrix's axes - int rm[TM] = pidm * TM + 0 ... TM; //(3) - int rn[TN] = pidn * TN + 0 ... TN; //(4) - // create 2D array of pointers - TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5) - TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6) - // write back using the transposition operator '^' - *py = ^(*px); //(7) - } - -At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows: - -- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values: - - .. code-block:: C - - pidm = 2 - pidn = 1 - - -- Statements (3) and (4) construct the ranges of indices: - - .. code-block:: C - - rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)] - rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)] - - -which will be used in statements (5) and (6) to construct tiles of pointers - -- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics: - -:: - - │ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │ - │ ⋮ ⋮ │ - │ ⋮ ⋮ │ - │ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx) │ - - -- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics: - -:: - - │ Y + (pidn*TN + 0) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + 0) + (pidm*TM + TM - 1)*ldy) │ - │ ⋮ ⋮ │ - │ ⋮ ⋮ │ - │ Y + (pidn*TN + TN - 1) + (pidn*TN + 0)*ldy, ..., ..., Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy) │ - -- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`. - - -================================== -A Note on Numpy-style Broadcasting -================================== - -The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example. - -First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix. - -.. code-block:: python - - import numpy as np - - vec = np.linspace(0,11,12) - mat = vec.reshape((4,3)) - -Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom). - -:: - - [[ 0, 1, 2], - [ 3, 4, 5], - [ 6, 7, 8], - [ 9, 10, 11]] - -Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like for tile 0 to have the values: - -:: - - [ 0, 1, 2], - [ 3, 4, 5], - -We construct ranges `rm` and `rn` as: -:: - - rm = [0, 1] - rn = [0, 1, 2] - -Using numpy-style broadcasting, we can add these together to create a matrix: - -:: - - rm[:, np.newaxis] + rn[np.newaxis, :] - - rn -> [0, 1, 2] - rm -> [0., [[0, 1, 2], - 1.] [1, 2, 3]] - -The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index -of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need -to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension). -In this case this is 3, so what we really need is: - -:: - - ldx = 3 - px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:] - -`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example). - -========================== -The __multipleof attribute -========================== - -The memory loads and store in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate because tensor dimensions can be easily made into nice powers of two in Deep Learning, due to batch-sizes and layer width being flexible. - -For this reason, Triton provides a __multipleof(N) attributes for variables that are guaranteed to always be multiple of N. In the case of Matrix Transpositions, vector loads can be enabled by modifying the function's signature as follows: - -.. code-block:: C - - __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, - int ldx __multipleof(8), - int ldy __multipleof(8)) { - // ... - } - - -========================== -Bounds Checking -========================== - - -You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below: - -.. code-block:: C - - // launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs - __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) { - // ... - // create bounds-checking mask - bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a) - bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b) - // conditional write-back using the conditional dereferencing operatior '*?()' - *?(checky)py = ^(*?(checkx)px); //(7) - } - - -Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`. - -A runnable version of this kernel is available `here `_. diff --git a/docs/tutorials/putting-it-all-together.rst b/docs/tutorials/putting-it-all-together.rst deleted file mode 100644 index 693a044c9..000000000 --- a/docs/tutorials/putting-it-all-together.rst +++ /dev/null @@ -1,180 +0,0 @@ -==================================================== -Putting It All Together -==================================================== - -In the previous tutorial, we saw how to write tensor-core-friendly matrix multiplication code competitive with cuBLAS in 20 lines of Triton code. Here, we will see how to wrap it into an automatically differentiable PyTorch functions for easy integration in your Deep Learning pipeline. - ------------------ -PyTriton Function ------------------ - -The PyTriton API provides a :code:`triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class - -.. code-block:: python - - import triton - - # Entry point - class _dot(torch.autograd.Function): - - @staticmethod - # Forward Pass - def forward(ctx, *args): - #... - - @staticmethod - # Backward Pass - def backward(ctx, dy): - #... - ------------------ -PyTriton Kernels ------------------ - - -PyTriton also provides a :code:`triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework bindings code. For our dot operation we create a kernel from the Triton code shown at the end of the previous tutorial. - -.. code-block:: python - - src = """ - __global__ void dot(TYPE * A, TYPE * B, TYPE * C, - int M, int N, int K, - int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) { - // prologue - int pm = get_program_id(0); - int pn = get_program_id(1); - int rm[TM] = pm * TM + 0 ... TM; - int rn[TN] = pn * TN + 0 ... TN; - int rk[TK] = 0 ... TK; - float c[TM, TN] = 0; - // pointers to operands - TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; - TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; - // prefetches operands - TYPE a[SHAPE_A] = (*pa); - TYPE b[SHAPE_B] = (*pb); - // reduction loop - for(int k = K; k > 0; k-= TK){ - c += USE_A @ USE_B; - pa = pa + TK * STRIDE_AK; - pb = pb + TK * STRIDE_BK; - a = *pa; - b = *pb; - } - // epilogue - int rcm[TM] = pm * TM + 0 ... TM; - int rcn[TN] = pn * TN + 0 ... TN; - TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc; - *pc = c; - } - """ - - kernel = triton.kernel(src) - - -At this point, `kernel` is a callable object which takes the same signature as the :code:`dot` function in our source code, except that pointers are treated as tensors: :code:`[tensor, tensor, tensor, int, int, int, int, int, int]`. - ------------------------ -Using PyTriton Kernels ------------------------ - - -However, in practice only A, B are provided by the user, and all the other :code:`int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the :code:`A` and :code:`B` tensors, and then returns the results of a call to :code:`kernel`: - -.. code:: python - - @staticmethod - def _call(a, b, transpose_a, transpose_b): - # extract shapes - shape_a = a.shape - shape_b = b.shape - M, Ka = shape_a[0], shape_a[1] - Kb, N = shape_b[0], shape_b[1] - # transpose shapes - if transpose_a: - M, Ka = Ka, M - if transpose_b: - Kb, N = N, Kb - # contiguous dimensions - lda = M if transpose_a else Ka - ldb = Kb if transpose_b else N - ldc = N - # data-type - dtype = a.dtype - # allocate output - c = triton.empty([M, N], dtype = dtype) - # launch grid - grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))] - # pre-processor definitions - defines = {# tile sizes - 'TYPE' : dtype, - 'AT' : transpose_a, - 'BT' : transpose_b, - 'TM' : [32, 64, 128], - 'TN' : [32, 64, 128], - 'TK' : [8], - # handle A transposition - 'USE_A' : '^a' if transpose_a else 'a', - 'STRIDE_AK' : 'lda' if transpose_a else '1', - 'STRIDE_AM' : '1' if transpose_a else 'lda', - 'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :', - 'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis', - 'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK', - # handle B transposition - 'USE_B' : '^b' if transpose_b else 'b', - 'STRIDE_BK' : '1' if transpose_b else 'ldb', - 'STRIDE_BN' : 'ldb' if transpose_b else '1', - 'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis', - 'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :', - 'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'} - return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, - grid=grid, num_warps=4, defines=defines) - - --------------------------------------------- -Automatic Differentiation --------------------------------------------- - -At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed. - - -Creating custom operations for Triton and PyTorch is very similar; programmers have to provide two static methods :code:`forward` and :code:`backward` that take a context as their first input: - -.. code:: python - - @staticmethod - def forward(ctx, a, b, transpose_a = False, transpose_b = False): - ctx.save_for_backward(a, b) - ctx.t_a = transpose_a - ctx.t_b = transpose_b - return _dot._call(a, b, transpose_a, transpose_b) - - @staticmethod - def backward(ctx, dy): - a, b = ctx.saved_tensors - t_a, t_b = ctx.t_a, ctx.t_b - if not t_a and not t_b: - da = _dot._call(dy, b, False, True) - db = _dot._call(a, dy, True, False) - elif not t_a and t_b: - da = _dot._call(dy, b, False, False) - db = _dot._call(dy, a, True, False) - elif t_a and not t_b: - da = _dot._call(b, dy, False, True) - db = _dot._call(a, dy, False, False) - elif t_a and t_b: - da = _dot._call(b, dy, True, True) - db = _dot._call(dy, a, True, True) - else: - assert False - return da, db, None, None, None, None, None, None, None - - -A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class. - -.. code:: python - - dot = _dot.apply - - -And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~ diff --git a/docs/tutorials/triton-vs-cuda.rst b/docs/tutorials/triton-vs-cuda.rst deleted file mode 100644 index c90190313..000000000 --- a/docs/tutorials/triton-vs-cuda.rst +++ /dev/null @@ -1,109 +0,0 @@ -*************** -Triton vs. CUDA -*************** - - - -The purpose of this tutorial is to explore in more depth the major differences between Triton and CUDA. To keep things simple, we will still be focusing on the following vector addition code: - -.. code-block:: C - - // Triton - // launch on a grid of (N + TILE - 1) / TILE programs - __global__ void add(float* z, float* x, float* y, int N){ - int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE; - bool check[TILE] = offset < N; - float* pz[TILE] = z + offset; - float* px[TILE] = x + offset; - float* py[TILE] = y + offset; - *?(check)pz = *?(check)*px + *?(check)py; - } - -And its CUDA equivalent: - - -.. code-block:: C - - // CUDA - // launch on a grid of (N + TILE - 1) / TILE programs - __global__ void add(float *z, float *x, float *y, int N) { - int off = blockIdx.x * TILE + threadIdx.x; - if(off < N){ - float *pz = z + off; - float *px = x + off; - float *py = y + off; - *pz = *px + *py - } - } - - -========================== -Automatic parallelization -========================== - -While the two above pieces of code may look at first sight similar, a closer look reveals one *fundamental* difference: While CUDA kernels are launched on a cooperative array of threads, **Triton kernel are single-threaded and automatically parallelized**. - -This is a major difference in programming model, which not only makes your life much easier as a programmer, but also allows the Triton compiler to automatically do all sorts of nice optimizations: - -- *Automatic shared memory allocation and synchronization* - - That's right; programmers don't need to worry about shared memory allocation, usage and synchronization. Instead, the Triton compiler will use complex program analysis techniques to determine when shared memory should be used, where it should be synchronized and how threads should access it to avoid memory bank conflicts. - -- *Automatic memory coalescing* - - When you write Triton code, you also don't need to worry about memory coalescing. The compiler will arrange threads so that global memory accesses are coalesced when possible. - -- *Automatic tensor core utilization* - - Using tensor cores on Volta and Turing is notoriously difficult. Code is hard to write and even harder to optimize. Fortunately, the Triton compiler can also generate very efficient tensor core instructions (e.g., :code:`mma.sync.m8n8k4`) when low-precision matrices are multiplied together: - - .. code-block:: C - - half A[16, 8] = ... // initialize A - half B[8, 16] = ... // initialize B - float C[16, 16] = dot(A, B); // uses Tensor Cores! - - -- *Automatic instruction predication* - - Contrary to CUDA, Triton directly exposes predicated instruction through masked load/store instructions. This enables the Triton compiler to generate predicated instructions in PTX directly, resulting in sometimes better performance than I/O operations wrapped inside conditionals. - -=========================== -Vector Addition - Revisited -=========================== - -In light of these optimizations, it turns out that the GPU code generated by our Triton-C vector addition code is actually more analogous to the following: - -.. code-block:: C - - // CUDA - // launch on a grid of (N + TILE - 1) / TILE programs - __global__ void add(float *z, float *x, float *y, int N) { - int off[4]; - #pragma unroll - for(int k = 0; k < 4; k++) - off[k] = blockIdx.x * TILE + threadIdx.x + k * blockSize.x; - #pragma unroll - for(int k = 0; k < 4; k++) - z[off[0]] = x[off[0]] + y[off[0]] - } - -This code is actually more complicated when x, y and z have :code:`half` type, because then the Triton compiler automatically vectorizes data accesses using :code:`half2` to ensure memory transactions of 32-bits per thread. - -============================ -Auto-Tuning -============================ - -Now assume that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be trivially done by adjusting some compilation parameters. For example: - -.. code-block:: python - - kernel = triton.kernel(src, defines = {'TILE': [256, 512, 1024]}, num_warps = [2, 4, 8]) - -would benchmark our above triton source-code for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel. - -============================= -Going Further -============================= - -The benefits of Triton become more and more pronounced as compute kernels get more and more complex. In the next few tutorials, you will see how to implement transposition and tensor-core-compatible matrix multiplication routine on par with cuBLAS and CUTLASS without having to know anything about GPU micro-architecture! \ No newline at end of file diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h index b29813fd4..6184e7364 100755 --- a/include/triton/driver/stream.h +++ b/include/triton/driver/stream.h @@ -29,7 +29,7 @@ public: static driver::stream* create(backend_t backend); // methods virtual void synchronize() = 0; - virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, void **args = NULL, size_t args_size = 0) = 0; + virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args = NULL, size_t args_size = 0) = 0; virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0; virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0; // template helpers @@ -44,7 +44,7 @@ class host_stream: public stream { public: host_stream(); void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, void **args, size_t args_size); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; @@ -55,7 +55,7 @@ public: cu_stream(CUstream str, bool take_ownership); cu_stream(); void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, void **args, size_t args_size); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index b5b96f1bb..319bc6fdb 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -4,27 +4,19 @@ #define _TRITON_RUNTIME_FUNCTION_H_ #include +#include #include #include #include #include #include -#include // codegen #include "triton/ir/context.h" -#include "triton/codegen/target.h" #include "triton/runtime/arg.h" #include "triton/runtime/error.h" -namespace llvm { - class Module; - class LLVMContext; -} - -class Parser; - +// driver forward declaration namespace triton { - namespace driver{ class module; class stream; @@ -32,26 +24,19 @@ namespace driver{ class context; class device; } - -namespace lang{ -class translation_unit; } - -namespace codegen{ -namespace analysis{ -class tiles; -} -} - +// ir forward declaration +namespace triton{ namespace ir { class module; class function; class context; } +} +namespace triton{ namespace runtime{ - typedef std::vector grid_t; typedef std::map params_t; template inline T convert(const std::string& name); @@ -72,8 +57,7 @@ enum asm_mode_t { struct options_space_t { typedef std::pair> define_t; std::vector defines; - std::vector num_warps; - std::vector recompile_key; + std::vector num_warps; }; struct options_t { @@ -81,88 +65,69 @@ struct options_t { T D(const std::string& name) const { return convert(defines.at(name)); } - bool operator<(const options_t& other) const { - return std::make_pair(defines, num_warps) < - std::make_pair(other.defines, other.num_warps); - } - std::string to_str() const; - - std::map defines; + std::unordered_map defines; size_t num_warps; }; + +/* ------------------------- */ + +class kernel{ +private: + static std::string preheader(); + static arg_type convert(ir::type *ty); + +public: + kernel(const std::string& src, const options_t& opt, driver::device *device); + void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector& grid) const; + // getters + const std::vector& get_sig() const { return sig_; } + +private: + void init_ir (const std::string &src); + void init_ker(); + void init_sig(); + +public: + const options_t opt; + +private: + driver::device* dev_; + // signature + std::vector sig_; + // triton context for parsing + ir::context ctx_; + // handles + std::shared_ptr ir_; + std::shared_ptr mod_; + std::shared_ptr ker_; +}; + class function { public: typedef std::function grid_fn_ty; + typedef std::pair> kernel_pair_t; + typedef std::map, kernel*> cache_t; private: - class caller { - public: - // constructors - caller(std::ifstream& ifs, const options_t& opt); - caller(ir::function *ir, std::shared_ptr program, const options_t& opt); - // serialization - void write(std::ofstream& ofs); - void read(std::ifstream& ifs); - // accessors - const options_t opt() const { return opt_; } - const driver::module* parent() const { return &*parent_; } - const driver::kernel* bin() const { return &*bin_; } - arg_type param_ty(size_t i) const { return param_tys_.at(i);} - const std::vector& param_tys() const { return param_tys_; } - - std::vector retune() const { return retune_; } - // entry points - void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size, const std::map>& = {}) const; - - private: - std::shared_ptr bin_; - std::shared_ptr parent_; - std::vector param_tys_; - std::vector retune_; - options_t opt_; - std::string name_; - }; - -private: - typedef std::pair> cache_key_t; - -private: - // cache - static std::string get_cache_prefix(); - // make - triton::lang::translation_unit *make_ast(const std::string &src); - std::unique_ptr make_ir(Parser &parser); - std::unique_ptr make_bin(ir::module &function, driver::device *device, const options_t &opt); - void make(driver::device *device, options_t opt); - void precompile(driver::device *device, const options_space_t& tuning_space); - // autotune - caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size); - + static void do_loop_nest(std::vector const & ranges, + std::function const &)> const & f); public: - static std::string preheader(); - -public: - function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = ""); - void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device); - void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device); - void set_cst(const char* name, void* data, size_t n_bytes); - std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt); + function(const std::string& src, const options_space_t& opt, driver::device *device); + void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); + void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream); + // auto-tuning + cache_t::iterator find_in_cache(void* args, size_t args_size); + kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); + // getters + const std::vector get_kernels() { return kernels_; } private: - std::map> cst_; - // pre-compilation - ir::context ctx_; - std::string src_; - options_space_t opt_; - std::set compiled_; - std::map> callers_; - std::vector args_off_; - size_t args_size_; - // caching - std::string cache_ref_; - std::string cache_path_; - std::map cache_; + void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device); + +private: + std::vector kernels_; + std::map, kernel*> cache_; }; } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 550061975..77e39e4f5 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -168,7 +168,6 @@ mma_layout::mma_layout(size_t num_warps, bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16); int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2; int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1; - pack_size_0 = pack_size_1 = 1; rep_ = {2*pack_size_0, 2*pack_size_1, 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; } diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc index 7b25e4c18..ff349c9ce 100755 --- a/lib/driver/stream.cc +++ b/lib/driver/stream.cc @@ -76,7 +76,7 @@ void host_stream::synchronize() { hst_->args.clear(); } -void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void **args, size_t args_size) { +void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size) { auto hst = kernel->module()->hst(); hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]); char* params = new char[args_size]; @@ -113,7 +113,7 @@ void cu_stream::synchronize() { dispatch::cuStreamSynchronize(*cu_); } -void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void** args, size_t args_size) { +void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size) { void *config[] = { CU_LAUNCH_PARAM_BUFFER_POINTER, args, CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index ca9e4bb03..7f648ab17 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -108,11 +108,9 @@ void Parser::ParseTranslationUnit() { FuncDef* Parser::ParseFuncDef(Identifier* ident) { auto funcDef = EnterFunc(ident); - if (funcDef->FuncType()->Complete()) { Error(ident, "redefinition of '%s'", funcDef->Name().c_str()); } - // TODO(wgtdkp): param checking auto funcType = ident->Type()->ToFunc(); funcType->SetComplete(true); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index c6e4a5032..f23c720c9 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -45,360 +45,25 @@ std::mutex mut; namespace triton{ namespace runtime { -/* --------------------- */ -/* HELPERS */ -/* --------------------- */ +/* --------------------------------- */ +/* --------------------------------- */ +/* --------------------------------- */ -void _loop_nest(std::vector const & ranges, - std::function const &)> const & f){ - size_t D = ranges.size(); - std::vector values(D, 0); - size_t i = D - 1; - while(true){ - f(values); - while(values[i]++ == ranges[i] - 1){ - if(i == 0) - return; - values[i--] = 0; - } - i = D - 1; - } -} - - -/* --------------------- */ -/* OPTIONS */ -/* --------------------- */ - -std::string options_t::to_str() const{ - std::string ret = "nw-" + std::to_string(num_warps); - for(const auto& x : defines){ - ret += '-'; - ret += x.first; - ret += '-'; - ret += x.second; - } - // legalize - for(char& x: ret){ - if(x == ' ' || x == '^' || x == ',' || x == ':') - x = '_'; - } - return ret; -} - - -/* --------------------- */ -/* CALLER OBJECT */ -/* --------------------- */ - -arg_type convert(ir::type *ty) { - if(ty->is_integer_ty(1)) - return INT1_T; - if(ty->is_integer_ty(8)) - return INT8_T; - if(ty->is_integer_ty(16)) - return INT16_T; - if(ty->is_integer_ty(32)) - return INT32_T; - if(ty->is_integer_ty(64)) - return INT64_T; - if(ty->is_half_ty()) - return HALF_T; - if(ty->is_float_ty()) - return FLOAT_T; - if(ty->is_double_ty()) - return DOUBLE_T; - if(ty->is_pointer_ty()) - return BUFFER_T; +arg_type kernel::convert(ir::type *ty) { + if(ty->is_integer_ty(1)) return INT1_T; + if(ty->is_integer_ty(8)) return INT8_T; + if(ty->is_integer_ty(16)) return INT16_T; + if(ty->is_integer_ty(32)) return INT32_T; + if(ty->is_integer_ty(64)) return INT64_T; + if(ty->is_half_ty()) return HALF_T; + if(ty->is_float_ty()) return FLOAT_T; + if(ty->is_double_ty()) return DOUBLE_T; + if(ty->is_pointer_ty()) return BUFFER_T; throw std::runtime_error("unknown type"); } -//void function::caller::write(std::ofstream &ofs) { -// // write name -// ofs << name_ << std::endl; -// // write signature -// for(size_t i = 0; i < param_tys_.size(); i++) -// ofs << param_tys_[i] << " "; -// ofs << std::endl; -// // write module -// std::string source = ((driver::cu_module*)(&*parent_))->ptx(); -// ofs << source; -//} -//void function::caller::read(driver::context* ctx, std::ifstream &ifs) { -// // read name -// std::getline(ifs, name_); -// // read signature -// std::string line; -// std::getline(ifs, line); -// std::istringstream current(line); -// int param; -// param_tys_.clear(); -// while(current >> param) -// param_tys_.push_back((arg_type)param); -// // read module -// std::string src((std::istreambuf_iterator(ifs)), -// std::istreambuf_iterator()); -// parent_.reset(new driver::cu_module(ctx, src)); -// bin_.reset(driver::kernel::create(&*parent_, name_.c_str())); - -//} - -//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt) -// : opt_(opt) { -// read(ctx, ifs); -//} - -function::caller::caller(ir::function *ir, - std::shared_ptr parent, const options_t& opt) - : parent_(parent), opt_(opt), name_(ir->get_name()) { - bin_.reset(driver::kernel::create(&*parent, name_.c_str())); - // extract signature - ir::function_type* ty = ir->get_fn_type(); - for(size_t i = 0; i < ty->get_num_params(); i++){ - param_tys_.push_back(convert(ty->get_param_ty(i))); - if(!ir->has_attr(i+1)) - continue; - for(ir::attribute attr: ir->attrs().at(i + 1)) - if(attr.get_kind() == ir::retune) - retune_.push_back(i); - } -} - - -void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size, const std::map>& csts) const { - // copy constants - for(const auto& cst: csts){ - std::unique_ptr buffer = parent()->symbol(cst.first.c_str()); - stream->write(&*buffer, true, 0, cst.second); - } - // set grid - if(_grid.size() > 3) - throw std::runtime_error("grid size must be no greater than 3"); - std::array grid; - for(size_t i = 0; i < 3; i++) - grid[i] = (i < _grid.size()) ? _grid[i] : 1; - // enqueue - stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, args, args_size); -} - - -/* --------------------- */ -/* FUNCTION */ -/* --------------------- */ - -// create Triton-IR from AST -std::unique_ptr function::make_ir(Parser& parser) { - ir::module* module = new ir::module("", ctx_); - Generator gen(&parser); - gen.Gen(module); - return std::unique_ptr(module); -} - -// create Binary from Triton-IR -std::unique_ptr function::make_bin(ir::module &module, driver::device* device, const options_t& opt) { - std::unique_ptr target = device->make_target(); - // generate llvm code - llvm::LLVMContext ctx; - std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); - // optimizations - bool cts_use_async = target->as_nvidia()->sm() >= 80; - // create passes - codegen::analysis::align align; - codegen::analysis::axes axes; - codegen::transform::cts cts(cts_use_async); - codegen::transform::disassociate disassociate; - codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); - codegen::analysis::liveness liveness(&layouts); - codegen::analysis::swizzle swizzle(&layouts, target.get()); - codegen::analysis::allocation allocation(&liveness); - codegen::transform::membar barriers(&liveness, &layouts, &allocation); - codegen::transform::dce dce; - codegen::transform::peephole peephole(target.get()); - codegen::transform::reassociate reassociate; - codegen::transform::coalesce coalesce(&align, &layouts); - codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); - // run passes - dce.run(module); - disassociate.run(module); - dce.run(module); - peephole.run(module); - dce.run(module); - align.run(module); - if(target->is_gpu()) - cts.run(module); - axes.run(module); - layouts.run(module); - coalesce.run(module); - dce.run(module); - align.run(module); - dce.run(module); - if(target->is_gpu()){ - reassociate.run(module); - cts.run(module); - } - peephole.run(module); - dce.run(module); - align.run(module); - axes.run(module); -// ir::print(module, std::cout); - layouts.run(module); - swizzle.run(module); - liveness.run(module); - allocation.run(module); - if(allocation.allocated_size() > device->max_shared_memory()) - throw exception::out_of_shared_memory(); - barriers.run(module); - isel.visit(module, *llvm); - std::unique_ptr res(driver::module::create(device, std::move(llvm))); -// if(res->spilled() > 256) -// throw exception::out_of_registers(); - return res; -} - - -// create Binary from options -void function::make(driver::device *device, options_t opt) { - if(callers_.find(opt) != callers_.end()) - return; - // pre-process - TokenSequence tokens; - Preprocessor cpp(&src_, true); - for(auto it: opt.defines) - cpp.AddMacro(it.first, &it.second); - cpp.Process(tokens); - // src -> ast - Parser parser(tokens); - parser.Parse(); - // ast -> triton-ir - auto ir = make_ir(parser); - // triton-ir -> binary - std::unique_ptr bin; - try{ - bin = make_bin(*ir, device, opt); - }catch(const exception::base&){ - throw; - } - // create callable - ir::function *tmp = ir->get_function_list()[0]; - callers_[opt].reset(new caller(tmp, std::move(bin), opt)); -} - -// precompile all kernels spanned by given options space -void function::precompile(driver::device* device, const options_space_t& space) { - // all ranges - std::vector ranges; - ranges.push_back(space.num_warps.size()); - for(const auto& x: space.defines) - ranges.push_back(x.second.size()); - // functor for source with given option - std::map err; - auto do_make = [&](std::vector params) { - // compilation options - unsigned i = 0; - options_t opt; - opt.num_warps = space.num_warps[params[i++]]; - for(auto D: space.defines) - opt.defines[D.first] = D.second[params[i++]]; - // compile - try{ - make(device, opt); - }catch(const exception::base& e){ - err[opt] = e.what(); - } - }; - // multi-threaded compilation - _loop_nest(ranges, do_make); - if(callers_.empty()){ - std::ostringstream dbg; - dbg << "Auto-Tuner could not find any valid configuration:" << std::endl; - for(auto x: err){ - dbg << "[ "; - dbg << x.first.num_warps << ", "; - dbg << "{ "; - for(const auto& y: x.first.defines) - dbg << '"' << y.first << "\"= \"" << y.second << "\", "; - dbg << " } ] -> " << x.second << std::endl; - } - throw exception::no_valid_configuration(dbg.str()); - } -} - -std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) { - make(device, opt); - const auto& fn = callers_.at(opt); - if(!fn) - return ""; - switch(mode){ - case ASM_LLIR:{ - return fn->parent()->llir(); - } - case ASM_NV_PTX: - case ASM_NV_SASS:{ - std::string ptx = ((driver::cu_module*)fn->parent())->ptx(); - // SASS - std::string input = std::tmpnam(nullptr); - std::string output = std::tmpnam(nullptr); - std::ofstream ofs(input); - ofs << ptx; - ofs.close(); - if(mode == ASM_NV_PTX) - return ptx; - std::string cmd; - int err; - // compile ptx - driver::cu_device* cu_device = (driver::cu_device*)device; - cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o"; - err = system(cmd.c_str()); - // disassemble - cmd = "cuobjdump --dump-sass " + input + ".o >> " + output; - err = system(cmd.c_str()); - std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/"); - std::string to_delete = " /*"; - std::ifstream ifs(output); - std::string line; - std::string sass; - while(std::getline(ifs, line)) - if(!std::regex_match(line, comment)) - sass += line + "\n"; - return sass; - } - default: - return ""; - } - - -} - -// returns program with best compilation options for given parameter -function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn, - void** args, size_t args_size) { - // fast path -- no autotuning necessary - if(callers_.size() == 1) - return &*callers_.begin()->second; - // run auto-tuner - double best_ts = INFINITY; - caller* ret = nullptr; - for(auto &x : callers_){ - if(x.second == nullptr) - throw std::runtime_error("configuration not compiled"); - caller* current = &*x.second; - double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size, cst_); }, - stream, true); - ret = (ts < best_ts) ? current : ret; - best_ts = std::min(ts, best_ts); - } - stream->synchronize(); - return ret; -} - -// set copy host buffer "data" into constant memory buffer "name" -void function::set_cst(const char* name, void* data, size_t n_bytes) { - cst_[std::string(name)] = std::vector((char*)data, (char*)data + n_bytes); -} - - -std::string function::preheader() { +std::string kernel::preheader() { return R"( #define bool _Bool #define true 1 @@ -452,67 +117,212 @@ typedef long int64; )"; } -std::string function::get_cache_prefix() { - //user-specified cache path - std::string result = tools::getenv("TRITON_CACHE_PATH"); - if(!result.empty()){ - if(tools::mkpath(result)==0) - return result; - } - //create in home - result = tools::getenv("HOME"); - if(!result.empty()) - { - result = result + "/.triton/cache/"; - if(tools::mkpath(result)==0) - return result; - } - return ""; +void kernel::init_ir(const std::string& src) { + // pre-process + TokenSequence tokens; + Preprocessor cpp(&src, true); + for(auto it: opt.defines) + cpp.AddMacro(it.first, &it.second); + cpp.Process(tokens); + // src -> ast + Parser parser(tokens); + parser.Parse(); + // ast -> triton-ir + ir::module* module = new ir::module("", ctx_); + Generator gen(&parser); + gen.Gen(module); + ir_.reset(module); } -function::function(const std::string &src, - const options_space_t& opt, - const std::string &cache_ref): - src_(src), opt_(opt), cache_ref_(cache_ref) { - // hash source code - unsigned char hash[20]; - sha1::calc((void*)src_.data(), src_.size(), hash); - // create cache path - char _hex[40]; - sha1::toHexString(hash, _hex); - std::string hex(_hex, _hex + 40); - cache_path_ = get_cache_prefix() + hex + "/"; - tools::mkpath(cache_path_); - // append pre-header to source - src_ = preheader() + src_; +void kernel::init_ker(){ + // triton-ir -> binary + std::unique_ptr bin; + std::unique_ptr target = dev_->make_target(); + // generate llvm code + llvm::LLVMContext ctx; + std::string name = ir_->get_function_list()[0]->get_name(); + std::unique_ptr llvm(new llvm::Module(name, ctx)); + // optimizations + bool cts_use_async = target->as_nvidia()->sm() >= 80; + // create passes + codegen::analysis::align align; + codegen::analysis::axes axes; + codegen::transform::cts cts(cts_use_async); + codegen::transform::disassociate disassociate; + codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); + codegen::analysis::liveness liveness(&layouts); + codegen::analysis::swizzle swizzle(&layouts, target.get()); + codegen::analysis::allocation allocation(&liveness); + codegen::transform::membar barriers(&liveness, &layouts, &allocation); + codegen::transform::dce dce; + codegen::transform::peephole peephole(target.get()); + codegen::transform::reassociate reassociate; + codegen::transform::coalesce coalesce(&align, &layouts); + codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); + // run passes + dce.run(*ir_); + disassociate.run(*ir_); + dce.run(*ir_); + peephole.run(*ir_); + dce.run(*ir_); + align.run(*ir_); + if(target->is_gpu()) + cts.run(*ir_); + axes.run(*ir_); + layouts.run(*ir_); + coalesce.run(*ir_); + dce.run(*ir_); + align.run(*ir_); + dce.run(*ir_); + if(target->is_gpu()){ + reassociate.run(*ir_); + cts.run(*ir_); + } + peephole.run(*ir_); + dce.run(*ir_); + align.run(*ir_); + axes.run(*ir_); + layouts.run(*ir_); + swizzle.run(*ir_); + liveness.run(*ir_); + allocation.run(*ir_); + if(allocation.allocated_size() > dev_->max_shared_memory()) + throw exception::out_of_shared_memory(); + barriers.run(*ir_); + isel.visit(*ir_, *llvm); + //if(res->spilled() > 256) + // throw exception::out_of_registers(); + mod_.reset(driver::module::create(dev_, std::move(llvm))); + ker_.reset(driver::kernel::create(&*mod_, name.c_str())); } -void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream, driver::device *device) { - // pre-compile kernels - if(callers_.empty()){ - precompile(device, opt_); +void kernel::init_sig() { + ir::function* fn = ir_->get_function_list()[0]; + ir::function_type* ty = fn->get_fn_type(); + for(size_t i = 0; i < ty->get_num_params(); i++){ + sig_.push_back(convert(ty->get_param_ty(i))); + if(!fn->has_attr(i+1)) + continue; } - // re-tuning key - cache_key_t key; - key.first = device; - key.second = callers_.begin()->second->retune(); - // auto-tune if necessary +} + +kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev): + opt(opt), dev_(dev) { + init_ir(preheader() + src); + init_ker(); + init_sig(); +} + +void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector& _grid) const{ + // set grid + if(_grid.size() > 3) + throw std::runtime_error("grid size must be no greater than 3"); + std::array grid; + for(size_t i = 0; i < 3; i++) + grid[i] = (i < _grid.size()) ? _grid[i] : 1; + // enqueue + stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size); +} + +/* --------------------------------- */ +/* --------------------------------- */ +/* --------------------------------- */ + +void function::do_loop_nest(std::vector const & ranges, + std::function const &)> const & f){ + size_t D = ranges.size(); + std::vector values(D, 0); + size_t i = D - 1; + while(true){ + f(values); + while(values[i]++ == ranges[i] - 1){ + if(i == 0) + return; + values[i--] = 0; + } + i = D - 1; + } +} + + +void function::init_kernels(const std::string& src, const options_space_t& opts, driver::device *device) { + // all ranges + std::vector ranges; + ranges.push_back(opts.num_warps.size()); + for(const auto& x: opts.defines) + ranges.push_back(x.second.size()); + // functor for source with given option + std::vector> err; + auto do_make = [&](std::vector params) { + // compilation options + unsigned i = 0; + options_t opt; + opt.num_warps = opts.num_warps[params[i++]]; + for(auto D: opts.defines) + opt.defines[D.first] = D.second[params[i++]]; + // compile + try{ + kernels_.push_back({opt, std::make_shared(src, opt, device)}); + }catch(const exception::base& e){ + err.push_back({opt, e.what()}); + } + }; + // multi-threaded compilation + do_loop_nest(ranges, do_make); + if(kernels_.empty()){ + std::ostringstream dbg; + dbg << "Auto-Tuner could not find any valid configuration:" << std::endl; + for(auto x: err){ + dbg << "[ "; + dbg << x.first.num_warps << ", "; + dbg << "{ "; + for(const auto& y: x.first.defines) + dbg << '"' << y.first << "\"= \"" << y.second << "\", "; + dbg << " } ] -> " << x.second << std::endl; + } + throw exception::no_valid_configuration(dbg.str()); + } +} + +kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) { + // fast path -- no autotuning necessary + if(kernels_.size() == 1) + return &*kernels_.begin()->second; + // auto-tuning key + std::vector key; auto it = cache_.find(key); - if(it == cache_.end()){ - auto best = autotune(stream, grid_fn, args, args_size); - it = cache_.insert({key, best}).first; + if(it != cache_.end()) + return it->second; + // run auto-tuner + double best_ts = INFINITY; + kernel* ret = nullptr; + for(auto &x : kernels_){ + kernel* current = &*x.second; + auto grid = grid_fn(x.first); + while(grid.size() < 3) + grid.push_back(1); + double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); }, + stream, true); + ret = (ts < best_ts) ? current : ret; + best_ts = std::min(ts, best_ts); } - // run - (*it->second)(stream, grid_fn(it->second->opt()), args, args_size, cst_); + stream->synchronize(); + it = cache_.insert({key, ret}).first; + return it->second; } -void function::operator()(void** args, - size_t args_size, - const grid_t& grid, - driver::stream* stream, driver::device *device) { - return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream, device); +function::function(const std::string& src, const options_space_t& opt, driver::device *device) { + init_kernels(src, opt, device); } +void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) { + runtime::kernel* fn = autotune(args, args_size, grid_fn, stream); + (*fn)(args, args_size, stream, grid_fn(fn->opt)); +} + +void function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) { + return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream); +} } diff --git a/python/examples/tutorials/add.py b/python/examples/tutorials/add.py deleted file mode 100644 index 542efa57b..000000000 --- a/python/examples/tutorials/add.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -import triton - -class _add(torch.autograd.Function): - src = """ -__global__ void add(float* z, float* x, float* y, int N) { - - int pid = get_program_id(0); - - int offset[TILE] = pid * TILE + 0 ... TILE; - float* pz[TILE] = z + offset; - float* px[TILE] = x + offset; - float* py[TILE] = y + offset; - - bool check[TILE] = offset < N; - - *pz = *px + *py; -} - """ - - kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4]) - - @staticmethod - def forward(ctx, x, y): - z = torch.empty_like(x).cuda() - N = x.numel() - grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),) - _add.kernel(z,x,y, N, grid=grid) - return z - -add = _add.apply - -# test -torch.manual_seed(0) -x = torch.rand(900).cuda() -y = torch.rand(900).cuda() -za = x + y -zb = add(x, y) -print(torch.allclose(za,zb)) diff --git a/python/examples/tutorials/copy.py b/python/examples/tutorials/copy.py deleted file mode 100644 index 7c5276242..000000000 --- a/python/examples/tutorials/copy.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import triton - -class _copy(torch.autograd.Function): - src = """ - __global__ void copy(TYPE * X, TYPE * Y, - int M __retune, - int N __retune, - int ldx __multipleof(8)) { - // extract program ID - int pidm = get_program_id(0); //(1) - int pidn = get_program_id(1); //(2) - - // create 1D range along the two matrix's axes - int rm[TM] = pidm * TM + 0 ... TM; //(3) - int rn[TN] = pidn * TN + 0 ... TN; //(4) - - // create 2D array of pointers - TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5) - TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6) - - *py = *px; - } - """ - - kernel = None ### initialize later when we know the sizes - - @staticmethod - def forward(ctx, x): - - M, N = x.shape - - ldx = N; - - dtype = x.dtype - - y = torch.empty((M,N)).cuda() - - defines= { - 'TYPE' : dtype, - 'TM' : [32,64,128], - 'TN' : [32,64,128], - } - grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))] - - if _copy.kernel is None: - _copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4]) - - _copy.kernel(x, y, M, N, ldx, grid=grid) - - return y - -copy = _copy.apply - -# test -torch.manual_seed(0) -x = torch.randn(8,4).cuda() - -print(x) - -ya = x -yb = copy(x) - -print() -print(ya) -print() -print(yb) -print(torch.allclose(ya, yb)) - -print(ya == yb) diff --git a/python/examples/tutorials/matmul.py b/python/examples/tutorials/matmul.py deleted file mode 100644 index 9b6904260..000000000 --- a/python/examples/tutorials/matmul.py +++ /dev/null @@ -1,143 +0,0 @@ -import torch -import triton - -class _dot(torch.autograd.Function): - src = """ -#define STM 4 -#define STN 4 - -__global__ void dot(TYPE * A __noalias __readonly __aligned(16), - TYPE * B __noalias __readonly __aligned(16), - TYPE * C __noalias __aligned(16), - float alpha, - int M __retune, - int N __retune, - int K __retune __multipleof(16), - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc __multipleof(8)) { - // prologue - int pid = get_program_id(0); - int pidz = get_program_id(2); - int gridm = M / TM; - int gridn = N / TN; - int stgridm = (gridm + STM - 1) / STM; - int stgridn = (gridn + STN - 1) / STN; - int stid = pid / (STM * STN); - int laneid = pid % (STM * STN); - int stm = stid / stgridn; - int stn = stid % stgridn; - int lanem = laneid / STN; - int lanen = laneid % STN; - int pidm = stm*STM + lanem; - int pidn = stn*STN + lanen; - int rm[TM] = pidm * TM + 0 ... TM; - int rn[TN] = pidn * TN + 0 ... TN; - - // reduction splitting - K = K / TZ; - int rk[TK] = pidz * K + 0 ... TK; - - // pointers to operands - int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; - int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; - TYPE* pa[TM, TK] = A + offa; - TYPE* pb[TK, TN] = B + offb; - - // prefetches operands - bool checka[TM, TK] = rk[newaxis, :] < K; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - - // reduction loop - float acc[TM, TN] = 0; - for(int k = K; k > 0; k -= TK){ - bool checka[TM, TK] = k > TK; - bool checkb[TK, TN] = k > TK; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - acc += a @ b; - a = *?(checka)pa; - b = *?(checkb)pb; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; - - // epilogue - int rxm[TM] = pidm * TM + 0 ... TM; - int rxn[TN] = pidn * TN + 0 ... TN; - int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :]; - TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N); - -#if (TZ==1) - *?(checkc) pc = c; -#else - // accumulate partial result using spin-locks - int *plock = locks + pid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); -#endif -} - """ - - @staticmethod - def forward(ctx, a, b): - c = _dot._call(a,b) - return c - - - kernel = dict() - - @staticmethod - def _call(a, b): - # create kernel if necessary - dtype = a.dtype - if dtype not in _dot.kernel: - defines = { - 'TYPE' : dtype, - 'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN', - 'STRIDE_AM': 'lda', 'STRIDE_AK': '1', - 'STRIDE_BN': '1', 'STRIDE_BK': 'ldb', - 'TM' : [128], - 'TN' : [128], - 'TK' : [32], - 'TZ' : [1] - } - _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines) - kernel = _dot.kernel[dtype] - # allocate output - M, K = a.shape - K, N = b.shape - c = torch.empty([M,N], dtype=dtype, device=a.device) - print(kernel.asm('sass', c.device)) - print(kernel.asm('ptx', c.device)) - # enqueue - grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))] - time = kernel(a, b, c, 1., M, N, K, - a.stride(0), b.stride(0), c.stride(0), grid=grid) - return c - - -dot = _dot.apply - -torch.manual_seed(0) - -M, N, K = 4096, 4096, 4096 -a = torch.rand((M, K)).cuda().half() -b = torch.rand((K, N)).cuda().half() - -#a[:] = 1 -#b[:] = 1 - -zc = torch.matmul(a,b) -zc_ = dot(a,b) -print(torch.allclose(zc, zc_)) diff --git a/python/examples/tutorials/trans.py b/python/examples/tutorials/trans.py deleted file mode 100644 index be31cd2cc..000000000 --- a/python/examples/tutorials/trans.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import triton - -class _transpose(torch.autograd.Function): - src = """ - __global__ void transpose(TYPE * X, TYPE * Y, - int M __retune, - int N __retune, - int ldx __multipleof(8), int ldy __multipleof(8)) { - // extract program ID - int pidm = get_program_id(0); //(1) - int pidn = get_program_id(1); //(2) - - // create 1D range along the two matrix's axes - int rm[TM] = pidm * TM + 0 ... TM; //(3) - int rn[TN] = pidn * TN + 0 ... TN; //(4) - - // create 2D array of pointers - TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5) - TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6) - - // create bounds-checking mask - bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a) - bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b) - - // conditional write-back using the conditional dereferencing operatior '*?()' - *?(checky)py = ^(*?(checkx)px); //(7) - } - """ - - kernel = None ### initialize later when we know the sizes - - @staticmethod - def forward(ctx, x): - - M, N = x.shape - - ldx = N - ldy = M - - dtype = x.dtype - - y = torch.empty((N,M)).cuda() - - defines= { - 'TYPE' : dtype, - 'TM' : [32,64,128], - 'TN' : [32,64,128], - } - - grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))] - - if _transpose.kernel is None: - _transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4]) - - _transpose.kernel(x, y, M, N, ldx, ldy, grid=grid) - - return y - -transpose = _transpose.apply - -# test -torch.manual_seed(0) -x = torch.randn(1024,128).cuda() - -print(x) - -ya = torch.t(x) -yb = transpose(x) -print() -print(ya) -print() -print(yb) -print(torch.allclose(ya, yb)) - -print(ya == yb) diff --git a/python/setup.py b/python/setup.py index 898b7f6c5..c3c7e8842 100644 --- a/python/setup.py +++ b/python/setup.py @@ -95,25 +95,18 @@ class CMakeBuild(build_ext): subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) -find_llvm() - -directories = [x[0] for x in os.walk(os.path.join('src', 'include'))] -data = [] -for d in directories: - for htype in ['h', 'hpp']: - files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False) - data += [os.path.relpath(f, 'src') for f in files] - setup( name='triton', - version='0.3.0', + version='1.0.0', author='Philippe Tillet', - author_email='ptillet@g.harvard.edu', + author_email='phil@openai.com', description='A language and compiler for custom Deep Learning operations', long_description='', - packages=['triton', 'triton/_C'], - install_requires=['numpy', 'torch', 'sympy'], - package_data={'': data}, + packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'], + install_requires=['numpy', 'torch'], + package_data={'triton/ops': ['*.c'], + 'triton/ops/blocksparse': ['*.c']}, + include_package_data=True, ext_modules=[CMakeExtension('triton', 'triton/_C/')], cmdclass=dict(build_ext=CMakeBuild), zip_safe=False, @@ -122,7 +115,7 @@ setup( url='https://github.com/ptillet/triton/', download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz', classifiers=[ - 'Development Status :: 4 - Beta', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package + 'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package 'Intended Audience :: Developers', # Define that your audience are developers 'Topic :: Software Development :: Build Tools', 'License :: OSI Approved :: MIT License', # Again, pick a license diff --git a/python/src/bindings.cc b/python/src/bindings.cc index f6f34836d..456860ce1 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -13,15 +13,19 @@ #include "triton/ir/function.h" using namespace triton; - namespace rt = triton::runtime; +namespace drv = triton::driver; typedef std::pair map_key_t; -std::map> id_grid_map; -std::map> id_fn_map; -CUstream torch_get_cuda_stream(int64_t dev_id); -CUdevice torch_get_cuda_device(int64_t dev_id); +std::map> id_grid_map; +std::map> id_fn_map; +std::map> tt_devices; +std::map> tt_streams; +std::unordered_map opt_cache_; +extern CUstream torch_get_cuda_stream(int64_t dev_id); +extern CUdevice torch_get_cuda_device(int64_t dev_id); + /* Grid utilities */ @@ -36,106 +40,123 @@ void delete_grid(const map_key_t& key) { /* Function utilities */ -void register_fn(const map_key_t& key, +void register_fn(int op_id, + int dev_id, const std::string& src, const rt::options_space_t& opt) { - if(id_fn_map.find(key) == id_fn_map.end()) - id_fn_map[key].reset(new rt::function(src, opt, "")); + if(tt_devices.find(dev_id) == tt_devices.end()) { + driver::device* device; + driver::stream* stream; + if(dev_id >= 0){ + device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false); + stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false); + } + else{ + device = new triton::driver::host_device(); + stream = new triton::driver::host_stream(); + } + tt_devices[dev_id].reset(device); + tt_streams[dev_id].reset(stream); + } + if(id_fn_map.find(op_id) == id_fn_map.end()){ + id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id])); + } + for(const auto& k: id_fn_map[op_id]->get_kernels()){ + const rt::options_t* opt = &k.first; + pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference); + for(auto x: opt->defines) + if(std::all_of(x.second.begin(), x.second.end(), ::isdigit)) + obj.attr(x.first.c_str()) = std::stoi(x.second); + opt_cache_[&k.second->opt] = obj; + } + } -void delete_fn(const map_key_t& key) { - id_fn_map.erase(key); +void delete_fn(int op_id) { + id_fn_map.erase(op_id); } -std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) { - triton::driver::cu_device device(key.second, false); - return id_fn_map[key]->get_asm(mode, &device, opt); -} void cleanup() { id_grid_map.clear(); id_fn_map.clear(); + opt_cache_.clear(); } size_t make_op_id() { return id_fn_map.size(); } -/* Function signature */ -void make_module(const std::string& src, ir::module* ir, - const runtime::options_space_t& opt) { - std::string copy = triton::runtime::function::preheader() + src; - // pre-process - TokenSequence tokens; - Preprocessor cpp(©, true); - for(auto it: opt.defines){ - cpp.AddMacro(it.first, &it.second[0]); - } - cpp.Process(tokens); - // parse - Parser parser(tokens); - parser.Parse(); - Generator gen(&parser); - gen.Gen(ir); +std::vector get_fn_signature(size_t op_id) { + return id_fn_map[op_id]->get_kernels()[0].second->get_sig(); } -std::vector get_fn_signature(const std::string& src, - const runtime::options_space_t& opt) { - // triton-ir code-gen - ir::context ctx; - auto ir = std::shared_ptr(new ir::module("", ctx)); - make_module(src, &*ir, opt); - // function - ir::function* fn = ir->get_function_list().front(); - // extract signature - std::vector ret; - ir::function_type* ty = fn->get_fn_type(); - for(size_t i = 0; i < ty->get_num_params(); i++) - ret.push_back(rt::convert(ty->get_param_ty(i))); - return ret; +void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){ + rt::function* fn = id_fn_map.at(op_id).get(); + (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]); + + // for(size_t n = 0; n < constant_names.size(); n++){ + // const torch::Tensor& x = constant_vals[n]; + // fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size()); } -typedef triton::runtime::options_t options_t; -typedef triton::runtime::options_space_t options_space_t; +pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){ + rt::function* fn = id_fn_map.at(op_id).get(); + auto wrapper = [&grid](const rt::options_t& opt){ + pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference); + for(auto x: opt.defines) + if(std::all_of(x.second.begin(), x.second.end(), ::isdigit)) + obj.attr(x.first.c_str()) = std::stoi(x.second); + return grid(*obj.cast()); + }; + rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]); + return opt_cache_.at(&kernel->opt); +} + + +void init_superblocking(pybind11::module &m); +void init_launch(pybind11::module &m); PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; // bindings for triton classes pybind11::enum_(m, "arg_type") - .value("int1", rt::INT1_T) - .value("int8", rt::INT8_T) - .value("int16", rt::INT16_T) - .value("int32", rt::INT32_T) - .value("int64", rt::INT64_T) - .value("half", rt::HALF_T) - .value("float", rt::FLOAT_T) + .value("int1" , rt::INT1_T) + .value("int8" , rt::INT8_T) + .value("int16" , rt::INT16_T) + .value("int32" , rt::INT32_T) + .value("int64" , rt::INT64_T) + .value("half" , rt::HALF_T) + .value("float" , rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); pybind11::enum_(m, "asm_mode") - .value("ptx", rt::ASM_NV_PTX) + .value("ptx" , rt::ASM_NV_PTX) .value("sass", rt::ASM_NV_SASS); - pybind11::class_(m, "options") - .def(pybind11::init<>()) - .def("d", &options_t::D) - .def_readwrite("num_warps", &options_t::num_warps) - .def_readwrite("defines" , &options_t::defines); + pybind11::class_(m, "options", pybind11::dynamic_attr()) + .def_readwrite("num_warps", &rt::options_t::num_warps) + .def_readwrite("defines" , &rt::options_t::defines); - pybind11::class_(m, "options_space") + pybind11::class_(m, "options_space") .def(pybind11::init<>()) - .def_readwrite("defines", &options_space_t::defines) - .def_readwrite("num_warps", &options_space_t::num_warps); + .def_readwrite("num_warps", &rt::options_space_t::num_warps) + .def_readwrite("defines" , &rt::options_space_t::defines); // hooks into triton constructs since frameworks may not use pybind11 m.def("get_fn_signature", &get_fn_signature); - m.def("get_fn_asm", &get_fn_asm); + // m.def("get_fn_asm", &get_fn_asm); m.def("register_grid", ®ister_grid); m.def("delete_grid", &delete_grid); m.def("register_fn", ®ister_fn); m.def("delete_fn", &delete_fn); m.def("make_op_id", &make_op_id); m.def("cleanup", &cleanup); - ; + m.def("autotune", &autotune, pybind11::return_value_policy::reference); + m.def("launch_kernel", &launch_kernel); + + init_launch(m); + init_superblocking(m); } diff --git a/python/src/launch.cc b/python/src/launch.cc deleted file mode 100644 index 999d9c595..000000000 --- a/python/src/launch.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments -// as a string constructed with struct.pack in python - -#include "triton/driver/buffer.h" -#include "triton/driver/stream.h" -#include "triton/runtime/function.h" -#include "triton/tools/bench.hpp" -#include "torch/script.h" -#include "ATen/cuda/CUDAContext.h" -#include -#include - - -namespace rt = triton::runtime; -namespace drv = triton::driver; - -typedef std::pair map_key_t; -extern std::map> id_grid_map; -extern std::map> id_fn_map; -std::shared_ptr host_device; -std::shared_ptr host_context; -std::shared_ptr host_stream; - -int64_t cdiv_sum(torch::Tensor x, int64_t div){ - TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor") - auto _x = x.accessor(); - int64_t ret = 0; - for(size_t i = 0; i < x.size(0); i++) - ret += (_x[i] + div - 1) / div; - return ret; -} - -void init_host_stream() { - if(!host_stream){ - host_device.reset(new drv::host_device()); - host_context.reset(drv::context::create(&*host_device)); - host_stream.reset(drv::stream::create(host_context->backend())); - } -} - -CUstream torch_get_cuda_stream(int64_t dev_id) { - return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream(); -} - -CUdeviceptr torch_get_cuda_device(int64_t dev_id) { - CUdevice ret; - triton::driver::dispatch::cuDeviceGet(&ret, dev_id); - return ret; -} - -void synchronize(int64_t dev_id) { - if(dev_id == -1){ - init_host_stream(); - host_stream->synchronize(); - } - else{ - triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false); - stream.synchronize(); - } -} - -torch::Tensor cuda_empty_like(torch::Tensor x){ - if(x.nbytes() == 0) - return torch::empty_like(x); - void* data; - cudaMalloc(&data, x.nbytes()); - auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options()); - return ret; -} - -void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, - const std::vector& constant_names, const std::vector& constant_vals){ - rt::function* fn = id_fn_map.at({op_id, dev_id}).get(); - for(size_t n = 0; n < constant_names.size(); n++){ - const torch::Tensor& x = constant_vals[n]; - fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size()); - } - if(dev_id == -1){ - init_host_stream(); - (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device); - } - else{ - C10_CUDA_CHECK(cudaSetDevice(dev_id)); - triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false); - triton::driver::cu_device device(torch_get_cuda_device(dev_id), false); - (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device); - } -} - - -static auto registry = torch::RegisterOperators() - .op("triton::launch_kernel", &launch_kernel) - .op("triton::cuda_empty_like", &cuda_empty_like) - .op("triton::cdiv_sum", &cdiv_sum) - .op("triton::synchronize", &synchronize); diff --git a/python/src/torch/launch.cc b/python/src/torch/launch.cc new file mode 100644 index 000000000..bd9461e0c --- /dev/null +++ b/python/src/torch/launch.cc @@ -0,0 +1,83 @@ +// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments +// as a string constructed with struct.pack in python + +#include +#include +#include +#include "triton/driver/buffer.h" +#include "triton/driver/stream.h" +#include "triton/runtime/function.h" +#include "triton/tools/bench.hpp" +#include "torch/script.h" +#include "ATen/cuda/CUDAContext.h" +#include +#include + + +namespace rt = triton::runtime; +namespace drv = triton::driver; + +typedef std::pair map_key_t; +extern std::map> id_grid_map; +extern std::map> id_fn_map; +extern std::map> tt_devices; +extern std::map> tt_streams; + + +int64_t cdiv(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +int64_t largest_pow2_divisor(int64_t a){ + if(a % 8 == 0) return 8; + if(a % 4 == 0) return 4; + if(a % 2 == 0) return 2; + return 1; +} + +int64_t cdiv_sum(torch::Tensor x, int64_t div){ + TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor") + auto _x = x.accessor(); + int64_t ret = 0; + for(size_t i = 0; i < x.size(0); i++) + ret += (_x[i] + div - 1) / div; + return ret; +} + +CUstream torch_get_cuda_stream(int64_t dev_id) { + return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream(); +} + +CUdeviceptr torch_get_cuda_device(int64_t dev_id) { + CUdevice ret; + triton::driver::dispatch::cuDeviceGet(&ret, dev_id); + return ret; +} + +void synchronize(int64_t dev_id) { + tt_streams[dev_id]->synchronize(); +} + +torch::Tensor cuda_empty_like(torch::Tensor x){ + if(x.nbytes() == 0) + return torch::empty_like(x); + void* data; + cudaMalloc(&data, x.nbytes()); + auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options()); + return ret; +} + +void cuda_set_device(int64_t dev_id) { + if(dev_id >= 0) + C10_CUDA_CHECK(cudaSetDevice(dev_id)); +} + + +void init_launch(pybind11::module &m) { + m.def("cuda_set_device", &cuda_set_device); + m.def("cuda_empty_like", &cuda_empty_like); + m.def("largest_pow2_divisor", &largest_pow2_divisor); + m.def("cdiv", &cdiv); + m.def("cdiv_sum", &cdiv_sum); + m.def("synchronize", &synchronize); +} \ No newline at end of file diff --git a/python/src/torch/superblock.cc b/python/src/torch/superblock.cc new file mode 100644 index 000000000..2243eec79 --- /dev/null +++ b/python/src/torch/superblock.cc @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#ifdef _OPENMP +#include +#endif + +typedef std::vector> ret_t; + +void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){ + size_t H = layout.size(0); + size_t M = layout.size(1); + size_t N = layout.size(2); + torch::Tensor tmp = torch::zeros_like(layout); + auto _tmp = tmp.accessor (); + auto _layout = layout.accessor (); + auto _idx = idx.accessor (); + auto _scratch = scratch.accessor(); + std::vector current(H, 0); + #ifdef _OPENMP + #pragma omp parallel for + #endif + for(size_t h = 0; h < H; h++){ + // surrounding indices + std::vector ii_left(max_width, -1); + std::vector> ii_top(max_width, std::vector(N, -1)); + + for(size_t m = 0; m < M; m++){ + for(size_t n = 0; n < N; n++){ + int v = _layout[h][m][n]; + if(v == 0) + continue; + int n_left= ii_left[max_width-1]; + int m_top = ii_top [max_width-1][n]; + int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0; + int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0; + int topleft = (m_top >=0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0; + int width = std::min(left, std::min(top, topleft)) + 1; + + // reset width if blocks cannot be + // packed together (i.e., there's a 1 "in the middle") + for(int nn = n_left + 1; nn < n; nn++) + if(ii_top[max_width-1][nn] > ii_top[max_width-1][n]) + width = 1; + _tmp[h][m][n] = width; + + // update n_left ring buffer + for(int k = 0; k < max_width-1; k++) + ii_left[k] = ii_left[k+1]; + ii_left[max_width-1] = n; + + // update ii_top ring buffer + for(int k = 0; k < max_width-1; k++) + ii_top[k][n] = ii_top[k+1][n]; + ii_top[max_width-1][n] = m; + + // block is too small -- skip + if(width != max_width) + continue; + + // retained blocks are set to zeros + for(size_t km = 0; km < max_width; km++) + for(size_t kn = 0; kn < max_width; kn++) + { + int mm = ii_top[km][n]; + int nn = ii_left[kn]; + if(mm < 0 || nn < 0) + continue; + _layout[h][mm][nn] = 0; + _tmp[h][mm][nn] = 0; + _scratch[h][current[h]][0] = (int)h; + _scratch[h][current[h]][1] = (int)mm; + _scratch[h][current[h]][2] = (int)nn; + _scratch[h][current[h]][3] = _idx[h][mm][nn]; + current[h]++; + } + } + } + } + std::vector to_cat; + for(size_t h = 0; h < H; h++) + if(current[h] > 0) + to_cat.push_back(scratch[h].slice(0, 0, current[h])); + if(!to_cat.empty()) + ret.push_back(std::make_tuple(max_width, torch::cat(to_cat))); +} + + +ret_t superblock(torch::Tensor layout, int start_width) { + ret_t ret; + // block index + torch::Tensor idx = torch::zeros_like(layout); + int current = 0; + int64_t H = layout.size(0); + int64_t M = layout.size(1); + int64_t N = layout.size(2); + auto _layout = layout.accessor (); + auto _idx = idx.accessor(); + for(int64_t h = 0; h < H; h++) + for(int64_t m = 0; m < M; m++) + for(int64_t n = 0; n < N; n++){ + if(_layout[h][m][n] == 0) + continue; + _idx[h][m][n] = current++; + } + // scratch memory + torch::Tensor scratch = torch::empty({H, layout.sum().item(), 4}, layout.dtype()); + for(int max_width = start_width; max_width > 0; max_width /= 2) + segment_blocks(layout, idx, scratch, max_width, ret); + return ret; +} + + +void init_superblocking(pybind11::module &m) { + m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication"); +} \ No newline at end of file diff --git a/python/tests/test_blocksparse.py b/python/tests/test_blocksparse.py new file mode 100644 index 000000000..1292e5d96 --- /dev/null +++ b/python/tests/test_blocksparse.py @@ -0,0 +1,50 @@ +import itertools +import torch +import triton as tt +import pytest + +def sparsify_tensor(x, mask, block): + ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block] + return ret + +def mask_tensor(x, mask, block, value = 0): + ret = x.clone() + for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): + ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value + return ret + +@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK", + [ + (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\ + for at in [False, True]\ + for bt in [False, True]\ + for block in [16, 32, 64] + ] +) +def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384): + # set seed + torch.random.manual_seed(0) + # create inputs + a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda') + b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda') + shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE] + layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK)) + # triton result + op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) + ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a + rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b + rc = op(ra, rb) + # torch result + ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a + tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b + ta = ta.transpose(2, 3) if TRANS_A else ta + tb = tb.transpose(2, 3) if TRANS_B else tb + tc = torch.matmul(ta, tb) + tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc + tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc + # compare + rtol, atol = {torch.float32: (1e-4, 1e-5), + torch.float16: (1e-2, 1e-3)}[DTYPE] + assert torch.allclose(rc, tc, rtol=rtol, atol=atol) diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py new file mode 100644 index 000000000..46cabd3d4 --- /dev/null +++ b/python/tests/test_conv.py @@ -0,0 +1,17 @@ +import torch +import triton + + +def test_op(): + torch.manual_seed(0) + DTYPE = torch.float16 + N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3 + pad, stride, = (1, 1), (1, 1) + dilation = (1, 1) + a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5 + b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5 + th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation) + tt_c = triton.ops.conv(a, b, pad, stride) + rtol, atol = {torch.float32: (1e-4, 1e-5), + torch.float16: (1e-2, 1e-3)}[DTYPE] + assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py new file mode 100644 index 000000000..65b7f86a0 --- /dev/null +++ b/python/tests/test_matmul.py @@ -0,0 +1,96 @@ +import pytest +import itertools +import triton as tt +import torch as th + +@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[ + [ + # 1 warp + (16, 16, 16, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 16, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 16, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 32, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 32, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 32, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 64, 1, None, None, None, AT, BT, DTYPE), + (64, 16, 64, 1, None, None, None, AT, BT, DTYPE), + (16, 64, 64, 1, None, None, None, AT, BT, DTYPE), + # 2 warp + (64, 32, 64, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 64, 2, None, None, None, AT, BT, DTYPE), + (64, 32, 16, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 16, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 2, None, None, None, AT, BT, DTYPE), + # 4 warp + (128, 64, 16, 4, None, None, None, AT, BT, DTYPE), + (64, 128, 16, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 64, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 64, 4, None, None, None, AT, BT, DTYPE), + # 8 warp + (128, 256, 16, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 16, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 32, 8, None, None, None, AT, BT, DTYPE), + # variable input + (128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE), + (128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE), + (128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE), + (128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE) + ] + for DTYPE in ['float16'] + for AT in [False, True] + for BT in [False, True] +])) +def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE): + DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE] + th.manual_seed(0) + tt.ops._matmul.kernel = dict() + tt.ops._matmul.TM = [TM] + tt.ops._matmul.TN = [TN] + tt.ops._matmul.TK = [TK] + tt.ops._matmul.num_warps = [NWARP] + if M is None: M = TM + if N is None: N = TN + if K is None: K = TK + a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5 + b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5 + a = a.t() if AT else a + b = b.t() if BT else b + th_c = th.matmul(a, b) + tt_c = tt.ops.matmul(a, b) + rtol, atol = {th.float32: (1e-4, 1e-5), + th.float16: (1e-2, 1e-3)}[DTYPE] + assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol) + + +def do_bench(fn, flops = 0, warmup = 10, rep = 50): + start_event = th.cuda.Event(enable_timing=True) + end_event = th.cuda.Event(enable_timing=True) + ret = fn() + for i in range(warmup): + fn() + th.cuda.synchronize() + start_event.record() + for i in range(rep): + fn() + end_event.record() + th.cuda.synchronize() + time_ms = start_event.elapsed_time(end_event) / rep + return time_ms, flops/time_ms*1e-9, ret + + +def perf_op(dtype=th.float16, warmup=10, rep=50): + AT, BT = False, False + configs = [(N, N, N) for N in [128, 8192]] + for M, N, K in configs: + a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 + b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 + if AT: a = a.t() + if BT: b = b.t() + a = a[::,::] + b = b[::,::] + TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep) + TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep) + print((M, N, K), TH_MS, TT_MS) \ No newline at end of file diff --git a/python/tests/test_softmax.py b/python/tests/test_softmax.py new file mode 100644 index 000000000..7ac75c5af --- /dev/null +++ b/python/tests/test_softmax.py @@ -0,0 +1,8 @@ +import torch +import triton + +def test_op(M = 1024, N = 1024, dtype = torch.float32): + x = torch.randn(M, N, dtype=dtype, device='cuda') + th_y = torch.softmax(x, dim=-1) + tt_y = triton.ops.softmax(x) + assert torch.allclose(tt_y, th_y) \ No newline at end of file diff --git a/python/triton/__init__.py b/python/triton/__init__.py index a3c74c5f3..4d1fa4ba9 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,8 +1,13 @@ -from .kernel import * +# TODO: torch needs to be imported first +# or pybind11 shows `munmap_chunk(): invalid pointer` +import torch -# clean-up libtriton resources +# libtriton resources import atexit import triton._C.libtriton as libtriton @atexit.register def cleanup(): - libtriton.cleanup() \ No newline at end of file + libtriton.cleanup() + +from .kernel import * +from . import ops \ No newline at end of file diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 83ea60a58..012642245 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -15,18 +15,6 @@ codes = { libtriton.arg_type.buffer: 'P' } -sizes = { - libtriton.arg_type.int1: 1, - libtriton.arg_type.int8: 1, - libtriton.arg_type.int32: 4, - libtriton.arg_type.int64: 8, - libtriton.arg_type.half: 2, - libtriton.arg_type.float: 4, - libtriton.arg_type.double: 8, - libtriton.arg_type.buffer: 8 -} - - def th_to_triton(obj): tys = { torch.int8: 'char', @@ -43,92 +31,65 @@ def th_to_triton(obj): return [th_to_triton(x)[0] for x in obj] return [str(obj)] - def cdiv(a, b): - return (a + b - 1) // b + return libtriton.cdiv(a, b) -def cdiv_sum(a, b): - return torch.ops.triton.cdiv_sum(a, b) - def synchronize(device): dev_id = device.index dev_id = -1 if dev_id is None else dev_id - torch.ops.triton.synchronize(dev_id) + libtriton.synchronize(dev_id) + +def read(path): + with open(path, 'r') as f: + source = f.read() + return source + class kernel: - def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]): + def __init__(self, src, device, defines = dict(), num_warps = [4]): self.src = src self.opt = libtriton.options_space() self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()] self.opt.num_warps = num_warps + # device + assert device.type in ['cuda', 'cpu'] + if device.type == 'cuda': + self.device = torch.cuda.current_device() if device.index is None else device.index + if device.type == 'cpu': + self.device = -1 + # C++ function wrapper self.op_id = libtriton.make_op_id() - self.registered = set() - arg_types = libtriton.get_fn_signature(self.src, self.opt) - size = sum([sizes[x] for x in arg_types]) + libtriton.register_fn(self.op_id, self.device, self.src, self.opt) + # debug mode + self.is_debug = 'TRITON_DEBUG' in os.environ + # signature + arg_types = libtriton.get_fn_signature(self.op_id) self.tys = ''.join([codes[x] for x in arg_types]) - def asm(self, mode, device, **kwargs): - dev_id = device.index - # assembly mode - supported = { - 'ptx': libtriton.asm_mode.ptx, - 'sass': libtriton.asm_mode.sass, - } - if mode not in supported: - raise('ASM mode must be in ', supported.keys()) - mode = supported[mode] - # disambiguates #defines - libtriton.register_fn((self.op_id, dev_id), self.src, self.opt) - def _single_value_or_err(x, key): - if isinstance(x, list) and len(x) == 1: - return x[0] - if isinstance(x, list) and len(x) > 1: - if key in kwargs: - return kwargs[key] - raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. ' - 'Please supply an explicit value as a keyword argument.') - return str(x) - defines = dict() - for (D, V) in self.opt.defines: - defines[D] = _single_value_or_err(V, D) - opt = libtriton.options() - opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps') - opt.defines = defines - # run - return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt) - - def __call__(self, *args, **kwargs): - if 'TRITON_DEBUG_MODE' in os.environ: + def __call__(self, *args, grid): + # debug mode (initialize) + if self.is_debug: _args = args args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args] for i in range(len(args)): if isinstance(args[i], torch.Tensor): - args[i] = torch.ops.triton.cuda_empty_like(args[i]) + args[i] = libtriton.cuda_empty_like(args[i]) args[i].copy_(_args[i]) - torch.cuda.synchronize() - for x in args: - if isinstance(x, torch.Tensor): - device = x.device.index - device = -1 if device is None else device - break - # lazily register function for device - libtriton.register_fn((self.op_id, device), self.src, self.opt) - # launch grid - if 'grid' not in kwargs: - raise RuntimeError('Must provide grid for kernel launch') - grid = kwargs['grid'] - libtriton.register_grid((self.op_id, device), grid) - # re-allocate buffers for auto-tuning - if 'autotune_buf' in kwargs: - pass - # launch - params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args]) - names = list(kwargs['constants'].keys()) if 'constants' in kwargs else [] - constants = list(kwargs['constants'].values()) if 'constants' in kwargs else [] - torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants) - if 'TRITON_DEBUG_MODE' in os.environ: - torch.cuda.synchronize() + # initialize cuda device if necessary + libtriton.cuda_set_device(self.device) + # pack parameters into a byte buffer + params = pack(self.tys, *args) + # auto-tune if necessary + opt = libtriton.autotune(self.op_id, self.device, params, grid) + # run kernel + grid = grid(opt) + grid_0 = grid[0] + grid_1 = 1 if len(grid) < 2 else grid[1] + grid_2 = 1 if len(grid) < 3 else grid[2] + libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2) + # debug mode (finalize) + if self.is_debug: for i in range(len(args)): if isinstance(args[i], torch.Tensor): _args[i].copy_(args[i].clone()) diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py new file mode 100644 index 000000000..b2426c678 --- /dev/null +++ b/python/triton/ops/__init__.py @@ -0,0 +1,4 @@ +from .conv import _conv, conv +from .matmul import _matmul, matmul +from .softmax import _softmax, softmax +from . import blocksparse \ No newline at end of file diff --git a/python/triton/ops/blocksparse/__init__.py b/python/triton/ops/blocksparse/__init__.py new file mode 100644 index 000000000..64893145f --- /dev/null +++ b/python/triton/ops/blocksparse/__init__.py @@ -0,0 +1 @@ +from .matmul import matmul \ No newline at end of file diff --git a/python/triton/ops/blocksparse/matmul.c b/python/triton/ops/blocksparse/matmul.c new file mode 100644 index 000000000..e520b81e6 --- /dev/null +++ b/python/triton/ops/blocksparse/matmul.c @@ -0,0 +1,198 @@ + __global__ void NAME (TYPE* A __readonly __noalias __aligned(16), + TYPE* B __readonly __noalias __aligned(16), + TYPE* C __noalias __aligned(16), + int lda __multipleof(8), + int ldb __multipleof(8), + int ldc __multipleof(8), + long stride_za __multipleof(8), + long stride_zb __multipleof(8), + long stride_zc __multipleof(8), + long stride_ha __multipleof(8), + long stride_hb __multipleof(8), + long stride_hc __multipleof(8), + int DS0, int DS1, + int SDD_K __multipleof(16), + int SDD_off_width, + int* lut, int* locks, int nlocks) { + /* ---------------- */ + /* Prologue */ + /* ---------------- */ + // program ids + int pid0 = get_program_id(0); + int pid1 = get_program_id(1); + int pidz = get_program_id(2); +#ifdef SDD + // load LUT header + pid1 = pid1 + SDD_off_width; + int blockidm[TM] = (0 ... TM) / BLOCK; + int blockidn[TN] = (0 ... TN) / BLOCK; + int offlutm[TM] = blockidm*(TN/BLOCK)*4; + int offlutn[TN] = blockidn*4; + int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4; + int z = *(header + 0); + int i[TM] = *(header + 1 + offlutm); + int j[TN] = *(header + 2 + offlutn); + int AS1 = SDD_K / TZ; + int lockid = select(TZ > 1, 1, 0); + int offka = pid0 * AS1; + int offkb = pid0 * AS1; + int offmc = 0; + int offnc = 0; + int offpa = 0; + int offpb = 0; + int maxid = TZ; + int offhc = 0; + int offha = z; + int offhb = z; + int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK); + int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK); +#else + // load LUT header + int *header = lut + pid0 * 6; + int offset = *(header + 0); + int AS1 = *(header + 1); + int column = *(header + 2); + int depth = *(header + 3); + int lockid = *(header + 4); + int maxid = *(header + 5); + int *pinc = lut + offset; + int offhc = depth; +#ifdef DSD + // output offset + int offnc = pid1 * TN; + int offmc = column * TM; + int offpc = 0; + // dense input offset + int offnb = pid1 * TN; + int offkb __multipleof(8) = *pinc; + int offpb = 0; + // sparse input offset + int offma = 0; + int offka = 0; + long offpa __multipleof(8) = *(pinc + 1); + offpa = offpa * BLOCK * BLOCK; + int offha = 0; + int offhb = depth; +#endif +#ifdef DDS + // output offset + int offmc = pid1 * TM; + int offnc = column * TN; + int offpc = 0; + // dense input offset + int offma = pid1 * TM; + int offka __multipleof(8) = *pinc; + int offpa = 0; + // sparse input offset + int offnb = 0; + int offkb = 0; + long offpb __multipleof(8) = *(pinc + 1); + offpb = offpb * BLOCK * BLOCK; + int offha = depth; + int offhb = 0; +#endif + int ram[TM] = offma + 0 ... TM; + int rbn[TN] = offnb + 0 ... TN; +#endif + // initialize a, b pointers + int rka[TK] = offka + 0 ... TK; + int rkb[TK] = offkb + 0 ... TK; + TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK; + TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK; + // pre-fetch +#ifdef DDS + bool checkam[TM, TK] = ram[:, newaxis] < DS0; +#else + bool checkam[TM, TK] = AS1 > 0; +#endif +#ifdef DSD + bool checkbn[TK, TN] = rbn[newaxis, :] < DS0; +#else + bool checkbn[TK, TN] = AS1 > 0; +#endif + TYPE a[TM, TK] = checkam ? *pa : 0; + TYPE b[TK, TN] = checkbn ? *pb : 0; + + /* ---------------- */ + /* Inner Loop */ + /* ---------------- */ + // create result tile + float acc[TM, TN] = 0; + int step = TK; + for(int k = AS1; k > 0; k -= step) { + acc += a @ b; + // update pointers +#ifdef SDD + int inc_a = TK * STRIDE_AK; + int inc_b = TK * STRIDE_BK; +#else + pinc += 2; +#ifdef DSD + int inc_b __multipleof(8) = *pinc; + int inc_a __multipleof(8) = *(pinc + 1); + inc_b = inc_b * STRIDE_BK; +#endif +#ifdef DDS + int inc_a __multipleof(8) = *pinc; + int inc_b __multipleof(8) = *(pinc + 1); + inc_a = inc_a * STRIDE_AK; +#endif +#endif + pa += inc_a; + pb += inc_b; + // pre-fetch + bool checkak[TM, TK] = k > TK; + bool checkbk[TK, TN] = k > TK; + bool checka[TM, TK] = checkam && checkak; + bool checkb[TK, TN] = checkbk && checkbn; + a = *?(checka)pa; + b = *?(checkb)pb; + } + TYPE c[TM, TN] = acc; + + /* ---------------- */ + /* Epilogue */ + /* ---------------- */ + // initialize c pointers +#ifdef SDD + bool checkc[TM, TN] = 1; + // rematerialize + int rr_blockidm[TM] = (0 ... TM) / BLOCK; + int rr_blockidn[TN] = (0 ... TN) / BLOCK; + int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4; + int rr_offlutn[TN] = rr_blockidn*4; + int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :]; + int bkid[TM, TN] = *(header + off_bkid); + long offpc[TM, TN] = bkid * BLOCK * BLOCK; + // range within blocks + int rcm[TM] = (0 ... TM) % BLOCK; + int rcn[TN] = (0 ... TN) % BLOCK; +#else + int rcm[TM] = offmc + 0 ... TM; + int rcn[TN] = offnc + 0 ... TN; +#ifdef DSD + bool checkc[TM, TN] = rcn[newaxis, :] < DS0; +#endif +#ifdef DDS + bool checkc[TM, TN] = rcm[:, newaxis] < DS0; +#endif +#endif + TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN; + // write-back directly + if(lockid == 0) { + *?(checkc) pc = c; + } + // accumulate partial result using spin-locks + else { + int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1; + int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks; + for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); + int count = *pcount; + if(count == 0) + *?(checkc) pc = c; + else + *?(checkc) pc = c + *?(checkc)pc; + atomic_xchg(pcount, (count + 1) % maxid); + atomic_xchg(plock, 0); + } + } \ No newline at end of file diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py new file mode 100644 index 000000000..a3ae38d7e --- /dev/null +++ b/python/triton/ops/blocksparse/matmul.py @@ -0,0 +1,467 @@ +import triton +import triton._C.libtriton as libtriton +import torch +import os +import math + +src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) + +############## +# MAIN API # +############## +class _matmul(torch.autograd.Function): + + sdd_cache = dict() + dsd_cache = dict() + dds_cache = dict() + locks = dict() + + # Given an array sizes representing reduction size for each + # column of a block-mode matrix multiplication, + # performs load-balancing to achieve more smaller reductions + # between `seg_size` elements + @staticmethod + def load_balance(sizes, block): + # segment size + # heuristics taken from OpenAI blocksparse code + # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95 + max_size = sizes.max() + min_size = sizes[sizes != 0].min() + #if max_size > min_size * 2.0: + # seg_max = max(triton.cdiv(max_size, 4), min_size*2) + #else: + # seg_max = max_size + seg_max = max_size + seg_min = max(triton.cdiv(seg_max, 4), 4) + # split reduction into segments + div = sizes // seg_max + rem = sizes % seg_max + packs = div + (sizes < seg_min).long() + (rem >= seg_min).long() + width = packs.sum() + segments = torch.empty(width, dtype=sizes.dtype) + column = torch.empty_like(segments) + lockid = torch.zeros_like(segments) + maxid = torch.zeros_like(segments) + nlocks = 0 + current = 0 + col_idx = 0 + for i in range(len(sizes)): + d, r = div[i], rem[i] + isempty = sizes[i] < seg_min + last = current + d + (r >= seg_min) + isempty + # column id + column[current:last] = col_idx + # lock id + if d > 1 or (d == 1 and r >= seg_min): + nlocks += 1 + lockid[current:last] = nlocks + maxid[current:last] = last - current + # segment size + segments[current:current+d] = seg_max + if r < seg_min and not isempty: + segments[current+d-1] += r + if r >= seg_min or isempty: + segments[current+d] = r + current = last + col_idx += 1 + offsets = torch.zeros_like(segments) + offsets[1:] = torch.cumsum(segments[:-1], dim=0) + return segments, column, lockid, maxid, offsets + + @staticmethod + def get_locks(size, dev): + if dev not in _matmul.locks or \ + size > _matmul.locks[dev].size(0): + _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev) + return _matmul.locks[dev] + + ########################## + # SPARSE = DENSE x DENSE # + ########################## + + @staticmethod + def make_sdd_lut(layout, block, dtype, device): + start_width = 64 // block + superblocks = libtriton.superblock(layout.type(torch.int32), start_width) + luts, widths, packs = [], [], [] + for size, nnz in superblocks: + width = nnz.shape[0] // (size*size) + h = nnz[:, 0] + i = nnz[:, 1] + j = nnz[:, 2] + b = nnz[:, 3] + lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous() + luts.append(lut.type(torch.int32).to(device)) + widths.append(width) + packs.append(size) + # create locks + return luts, None, widths, packs + + @staticmethod + def _sdd_matmul(a, b, trans_a, trans_b, trans_c, + spdims, block, luts, num_locks, widths, packs): + + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + AS0 = a.size(0) + AS1 = a.size(1) + AS2 = a.size(3 if trans_a else 2) + AS3 = a.size(2 if trans_a else 3) + BS0 = b.size(0) + BS1 = b.size(1) + BS2 = b.size(3 if trans_b else 2) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + device = a.device + is_16_multiple = AS3 % 16 == 0 + is_32_multiple = AS3 % 32 == 0 + is_64_multiple = AS3 % 64 == 0 + if not is_16_multiple: + raise ValueError('Reduction size for SDD must be a multiple of 16') + # create kernel + total_width = sum([width*pack*pack for width,pack in zip(widths, packs)]) + c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device) + for lut, width, pack in zip(luts, widths, packs): + num_lock = 1 + key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) + if key not in _matmul.sdd_cache: + F32TK = [8, 16] + #F16TK = [16] + #F16TK += [32] if is_32_multiple else [] + #F16TK += [64] if is_64_multiple else [] + F16TK = [64] + TK = {torch.float32: F32TK, + torch.float16: F16TK}[dtype] + defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block, + 'TK': TK, 'TYPE': dtype, + 'STRIDE_AM': '1' if trans_a else 'lda', + 'STRIDE_AK': 'lda' if trans_a else '1', + 'STRIDE_BN': 'ldb' if trans_b else '1', + 'STRIDE_BK': '1' if trans_b else 'ldb', + 'STRIDE_CM': 'ldc', 'STRIDE_CN': '1', + 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'} + _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4]) + + kernel = _matmul.sdd_cache[key] + # create output + locks = _matmul.get_locks(2*width*AS0*num_lock, a.device) + # maximum grid size is 65535 + # so operation might be decomposed into multiple + # kernel calls + max_width = 49152 + for off_width in range(0, width, max_width): + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), + a.stride(2), b.stride(2), block, + a.stride(0), b.stride(0), c.stride(0), + a.stride(1), b.stride(1), c.stride(0), + AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock, + grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]) + # save for backward pass + return c + + ########################## + # DENSE = DENSE x SPARSE # + # DENSE = SPARSE x DENSE # + ########################## + + # Given a binary layout of 0s and 1s, + # Construct look-up table for efficient execution on GPUs + @staticmethod + def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx): + # load-balancing + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + segments = _empty.clone() + column = _empty.clone() + depth = _empty.clone() + lockid = _empty.clone() + maxid = _empty.clone() + offsets = _empty.clone() + current_offset = 0 + current_maxid = 0 + for z in range(layout.size(0)): + if trans: + sizes = torch.sum(layout[z, :, :], 1) + else: + sizes = torch.sum(layout[z, :, :], 0) + z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block) + z_depth = z * torch.ones_like(z_segments) + z_lockid[z_lockid > 0] += current_maxid + current_maxid = z_lockid.max() + # concatenate depth + segments = torch.cat((segments, z_segments)) + column = torch.cat((column, z_column)) + depth = torch.cat((depth, z_depth)) + maxid = torch.cat((maxid, z_maxid)) + offsets = torch.cat((offsets, current_offset + z_offsets)) + lockid = torch.cat((lockid, z_lockid)) + current_offset += layout[z, :, :].sum() + segments *= step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets)) + idx = transform(nnz[:, 2]*block) + xincs = idx.clone() + xincs[1:] -= idx[:-1] + # divide block into multiple steps + div = block // step + xincs = xincs.view(-1, 1).repeat(1, div) + xincs[:, 1:] = step + xincs[:, 0 ] -= (div-1)*step + # first increment for each reduction is actually the offset + xincs[offsets[segments>0], 0] = idx[offsets[segments>0]] + xincs = xincs.view(-1) + # block-mode input increments + if trans: + widx = torch.arange(num_blocks) + else: + widx = _empty.clone() + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum) + widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + widx = widx + wincs = widx*block*block + wincs[1:] -= widx[:-1]*block*block + wincs = wincs.view(-1, 1).repeat(1, div) + if trans: + wincs[:, 1:] = step + wincs[:, 0] -= (div-1)*step + else: + wincs[:, 1:] = step*block + wincs[:, 0] -= (div - 1)*step*block + wincs[offsets[segments>0], 0] = widx[offsets[segments>0]] + wincs = wincs.view(-1) + # adjust offset and segment size + offsets *= 2*div + segments *= div + # create header + width = column.size(0) + offsets += 6*width + header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous() + incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous() + incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + num_locks = max(1, lockid.max()) + return lut, num_locks, width, None + + @staticmethod + def _dds_matmul(a, b, trans_a, trans_b, trans_c, + spdims, block, lut, num_locks, width, packs): + # shapes / dtypes + AS0 = a.size(0) + AS1 = a.size(1) + AS2 = a.size(3 if trans_a else 2) + AS3 = a.size(2 if trans_a else 3) + BS0 = spdims[0] + BS1 = block * spdims[2 if trans_b else 1] + BS2 = block * spdims[1 if trans_b else 2] + dtype = a.dtype + # kernel + key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) + if key not in _matmul.dds_cache: + TM = [64, 128] if dtype == torch.float32 else [64, 128, 256] + TK = [8] if dtype == torch.float32 else [16] + defines = {'TM': TM, 'TN': block, 'TK': TK, + 'BLOCK': block, + 'TYPE': dtype, + 'STRIDE_AM': 1 if trans_a else 'lda', + 'STRIDE_AK': 'lda' if trans_a else 1, + 'STRIDE_BN': block if trans_b else 1, + 'STRIDE_BK': 1 if trans_b else block, + 'STRIDE_CM': '1' if trans_c else 'ldc', + 'STRIDE_CN': 'ldc' if trans_c else '1', + 'NAME': 'dds_kernel', + 'DDS': True} + _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4]) + kernel = _matmul.dds_cache[key] + # output + CS0 = AS0 + CS1 = AS1 + CS2 = BS2 if trans_c else AS2 + CS3 = AS2 if trans_c else BS2 + locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device) + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), + a.stride(2), block, c.stride(2), + a.stride(0), b.stride(0), c.stride(0), + a.stride(1), b.stride(1), c.stride(1), + AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks, + grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]) + return c + + @staticmethod + def _dsd_matmul(a, b, trans_a, trans_b, trans_c, + spdims, block, lut, num_locks, width, packs): + # shapes / dtypes + AS0 = spdims[0] + AS1 = block * spdims[2 if trans_a else 1] + AS2 = block * spdims[1 if trans_a else 2] + BS0 = b.size(0) + BS1 = b.size(1) + BS2 = b.size(3 if trans_b else 2) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # kernel + key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) + if key not in _matmul.dsd_cache: + TN = [64, 128] if dtype == torch.float32 else [64, 128] + TK = [8] if dtype == torch.float32 else [16] + defines = {'TM': block, 'TN': TN, 'TK': TK, + 'BLOCK': block, + 'TYPE': dtype, + 'STRIDE_AM': 1 if trans_a else block, + 'STRIDE_AK': block if trans_a else 1, + 'STRIDE_BN': 'ldb' if trans_b else '1', + 'STRIDE_BK': '1' if trans_b else 'ldb', + 'STRIDE_CM': '1' if trans_c else 'ldc', + 'STRIDE_CN': 'ldc' if trans_c else '1', + 'NAME': 'dsd_kernel', + 'DSD': True} + _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4]) + kernel = _matmul.dsd_cache[key] + # output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device) + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), + block, b.stride(2), c.stride(2), + a.stride(0), b.stride(0), c.stride(0), + a.stride(1), b.stride(1), c.stride(1), + BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks, + grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]) + return c + + fn = {'sdd': _sdd_matmul.__get__(object), + 'dsd': _dsd_matmul.__get__(object), + 'dds': _dds_matmul.__get__(object)} + + @staticmethod + def forward(ctx, a, b, trans_a, trans_b, trans_c, + mode, spdims, block, + c_lut, c_num_locks, c_width, c_packs, + da_lut, da_num_locks, da_width, da_packs, + db_lut, db_num_locks, db_width, db_packs): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, + c_lut, c_num_locks, c_width, c_packs) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_num_locks = da_num_locks + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.da_packs = da_packs + ctx.db_lut = db_lut + ctx.db_num_locks = db_num_locks + ctx.db_width = db_width + ctx.db_packs = db_packs + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs) + return da, db, None, None, None,\ + None, None, None, None,\ + None, None, None, None, None, None,\ + None, None, None, None, None, None,\ + None, None, None, None, None, None + +class matmul: + + def make_lut(self, dtype, device): + key = (dtype, device) + if key in self.lut_cache: + return self.lut_cache[key] + # C look-up table + layout, block = self.layout, self.block + step = 8 if dtype == torch.float32 else 16 + if self.mode == 'sdd': + c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device) + elif self.mode == 'dsd': + c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device) + elif self.mode == 'dds': + c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device) + # DA look-up table + if self.mode == 'sdd': + da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device) + elif self.mode == 'dsd': + da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device) + elif self.mode == 'dds': + da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device) + # DB look-up table + if self.mode == 'sdd': + db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device) + elif self.mode == 'dsd': + db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device) + elif self.mode == 'dds': + db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device) + self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\ + da_lut, da_num_locks, da_width, da_packs,\ + db_lut, db_num_locks, db_width, db_packs) + return self.lut_cache[key] + + def __init__(self, layout, block, mode, trans_a = False, trans_b = False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + # look-up table cache + self.lut_cache = dict() + # attributes + self.trans_a = trans_a + self.trans_b = trans_b + self.mode = mode + self.spdims = layout.shape + self.block = block + self.layout = layout + + # pad shapes of a tensor to make it + # compatible with kernel calls + @staticmethod + def _pad_shape(x, is_sparse): + max_dim = 3 if is_sparse else 4 + for i in range(max_dim - x.dim()): + x = x.unsqueeze(0) + return x + + def __call__(self, a, b): + c_lut, c_num_locks, c_width, c_packs,\ + da_lut, da_num_locks, da_width, da_packs,\ + db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device) + # pad shapes with ones + a = matmul._pad_shape(a, self.mode == 'dsd') + b = matmul._pad_shape(b, self.mode == 'dds') + # execute + c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, + self.mode, self.spdims, self.block, + c_lut, c_num_locks, c_width, c_packs, + da_lut, da_num_locks, da_width, da_packs, + db_lut, db_num_locks, db_width, db_packs) + return c \ No newline at end of file diff --git a/python/examples/tutorials/conv2d.py b/python/triton/ops/conv.c similarity index 63% rename from python/examples/tutorials/conv2d.py rename to python/triton/ops/conv.c index c30009cdf..d115ff540 100644 --- a/python/examples/tutorials/conv2d.py +++ b/python/triton/ops/conv.c @@ -1,16 +1,9 @@ -import torch -import triton - -class _conv(torch.autograd.Function): - src = """ - __global__ void conv(TYPE *A __noalias __readonly __aligned(16), +__global__ void conv(TYPE *A __noalias __readonly __aligned(16), TYPE *B __noalias __readonly __aligned(16), TYPE *C __noalias __aligned(16), float alpha, // equivalent matmul - int M __retune, - int N __retune, - int K __retune, + int M, int N, int K, // convolution properties int pad_h, int pad_w, int stride_h, int stride_w, // pointer increment @@ -130,73 +123,4 @@ class _conv(torch.autograd.Function): atomic_xchg(pcount, (count + 1) % TZ); atomic_xchg(plock, 0); #endif -} - """ - - kernel = dict() - - @staticmethod - def unpack(IDX, CI, R, S): - s = IDX % S - cr = IDX // S - r = cr % R - ci = cr // R - return ci, r, s - - @staticmethod - def forward(ctx, a, b, pad, stride, time): - # create kernel if necessary - dtype = a.dtype - # shapes - Z, CI, H, W = a.shape - _, R, S, CO = b.shape - P = (H + 2*pad[0] - R)//stride[0] + 1 - Q = (W + 2*pad[1] - S)//stride[1] + 1 - # compile kernel - if dtype not in _conv.kernel: - TK = 8 - defines = { - 'TYPE' : dtype, - 'TM' : [16, 32, 64, 128], - 'TN' : [16, 32, 64, 128], - 'TK' : [TK], - 'TZ' : [1], - 'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R, - } - idx = torch.arange(CI*R*S) - ci, r, s = _conv.unpack(idx, CI, R, S) - nci, nr, ns = _conv.unpack(idx + TK, CI, R, S) - delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3) - delta = delta.type(torch.int32).cuda() - _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines)) - delta, kernel = _conv.kernel[dtype] - # allocate output - c = torch.empty([Z, CO, P, Q], dtype=dtype) - # enqueue - grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')), - triton.cdiv(CO, opt.d('TN'))] - time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S, - pad[0], pad[1], stride[0], stride[1], - delta, - a.stride(0), a.stride(1), a.stride(2), a.stride(3), - b.stride(0), b.stride(1), b.stride(2), b.stride(3), - c.stride(0), c.stride(1), c.stride(2), c.stride(3), - grid=grid, bench=100) - return c - - - -conv = _conv.apply -torch.manual_seed(0) -Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3 -pad = (1, 1) -stride = (1, 1) -a = torch.rand((Z, CI, H, W)).cuda() -b = torch.rand((CI, R, S, CO)).cuda() -time = [None] -cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1]) -c = conv(a, b, pad, stride, time) -print((cc - c).abs().max() / max(cc.max(), c.max())) -print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12) -#zc = torch.matmul(a,b) -#zc_ = dot(a,b) +} \ No newline at end of file diff --git a/python/triton/ops/conv.py b/python/triton/ops/conv.py new file mode 100644 index 000000000..4e2961898 --- /dev/null +++ b/python/triton/ops/conv.py @@ -0,0 +1,57 @@ +import torch +import triton +import os + +class _conv(torch.autograd.Function): + src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c')) + kernel = dict() + + @staticmethod + def unpack(IDX, CI, R, S): + s = IDX % S + cr = IDX // S + r = cr % R + ci = cr // R + return ci, r, s + + @staticmethod + def forward(ctx, a, b, pad, stride): + # create kernel if necessary + dtype = a.dtype + device = a.device + # shapes + Z, CI, H, W = a.shape + _, R, S, CO = b.shape + P = (H + 2*pad[0] - R)//stride[0] + 1 + Q = (W + 2*pad[1] - S)//stride[1] + 1 + # compile kernel + if (dtype, device) not in _conv.kernel: + TK = 16 + defines = { + 'TYPE' : dtype, + 'TM' : [32, 64, 128], + 'TN' : [32, 64, 128], + 'TK' : [TK], + 'TZ' : [1], + 'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R, + } + idx = torch.arange(CI*R*S) + ci, r, s = _conv.unpack(idx, CI, R, S) + nci, nr, ns = _conv.unpack(idx + TK, CI, R, S) + delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3) + delta = delta.type(torch.int32).cuda() + _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines)) + delta, kernel = _conv.kernel[dtype] + # allocate output + c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device) + # enqueue + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S, + pad[0], pad[1], stride[0], stride[1], + delta.data_ptr(), + a.stride(0), a.stride(1), a.stride(2), a.stride(3), + b.stride(0), b.stride(1), b.stride(2), b.stride(3), + c.stride(0), c.stride(1), c.stride(2), c.stride(3), + grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)]) + return c + +conv = _conv.apply \ No newline at end of file diff --git a/tests/common/src/dot.h b/python/triton/ops/matmul.c similarity index 62% rename from tests/common/src/dot.h rename to python/triton/ops/matmul.c index 5973fe762..95f21fde2 100644 --- a/tests/common/src/dot.h +++ b/python/triton/ops/matmul.c @@ -1,26 +1,24 @@ -namespace src { - - const char *dot = -R"( #define STM 8 #define STN 8 -__global__ void dot(TYPE * A __noalias __readonly __aligned(16), - TYPE * B __noalias __readonly __aligned(16), - TYPE * C __noalias __aligned(16), - float alpha, - int M __retune, - int N __retune, - int K __retune __multipleof(16), - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc __multipleof(8), - int* locks) { +__global__ void matmul(TYPE * A __noalias __readonly __aligned(16), + TYPE * B __noalias __readonly __aligned(16), + TYPE * C __noalias __aligned(16), + float alpha, + int M, + int N, + int K __multipleof(16), + int lda __multipleof(LDA_POW2_DIV), + int ldb __multipleof(LDB_POW2_DIV), + int ldc __multipleof(LDC_POW2_DIV), + int* locks) { // prologue int pid = get_program_id(0); int pidz = get_program_id(2); int gridm = (M + TM - 1) / TM; int gridn = (N + TN - 1) / TN; + + // swizzle for better L2 performance int width = STM*gridn; int stm = pid / width; int RSTM = min(gridm - stm*STM, STM); @@ -34,14 +32,15 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), int rm[TM] = pidm * TM + 0 ... TM; int rn[TN] = pidn * TN + 0 ... TN; - // reduction splitting + // split-k for better parrallelism K = K / TZ; - int rk[TK] = pidz * K + 0 ... TK; + int rk[TK] = 0 ... TK; // pointers to operands - int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; - int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; + int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; + int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; TYPE* pa[TM, TK] = A + offa; TYPE* pb[TK, TN] = B + offb; + // prefetches operands bool checka[TM, TK] = rk[newaxis, :] < K; bool checkb[TK, TN] = rk[:, newaxis] < K; @@ -49,14 +48,25 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), TYPE b[TK, TN] = checkb ? *pb : 0; pa += TK * STRIDE_AK; pb += TK * STRIDE_BK; + // reduction loop float acc[TM, TN] = 0; for(int k = K; k > 0; k -= TK){ - bool checka[TM, TK] = k > TK; - bool checkb[TK, TN] = k > TK; +#if (IS_TK_DIV_K==1) + bool checkk[TK] = k > TK; +#else + bool checkk[TK] = rk < k - TK; +#endif + bool checka[TM, TK] = checkk[newaxis, :]; + bool checkb[TK, TN] = checkk[:, newaxis]; acc += a @ b; +#if (IS_TK_DIV_K==1) a = *?(checka)pa; b = *?(checkb)pb; +#else + a = checka ? *pa : 0; + b = checkb ? *pb : 0; +#endif pa += TK * STRIDE_AK; pb += TK * STRIDE_BK; } @@ -68,8 +78,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), int rcn[TN] = pidn * TN + 0 ... TN; int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :]; TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rcm[:, newaxis] < M && - rcn[newaxis, :] < N; + bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N; #if (TZ==1) *?(checkc) pc = c; #else @@ -85,7 +94,4 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), atomic_xchg(pcount, (count + 1) % TZ); atomic_xchg(plock, 0); #endif -} -)"; - -} +} \ No newline at end of file diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py new file mode 100644 index 000000000..f3a1348e0 --- /dev/null +++ b/python/triton/ops/matmul.py @@ -0,0 +1,80 @@ +import torch +import triton +import os + +class _matmul(torch.autograd.Function): + src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) + + TM = [128] + TN = [128] + TK = [32] + TZ = 1 + num_warps = [4] + + @staticmethod + def largest_pow2_divisor(N): + if N % 8 == 0: return 8 + if N % 4 == 0: return 4 + if N % 2 == 0: return 2 + return 1 + + + _locks = dict() + _kernels = dict() + @staticmethod + def _call(a, b): + dtype = a.dtype + device = a.device + # allocate output + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), dtype=dtype, device=device) + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous() + # kernel hash + is_a_row = a.stride(1) == 1 + is_b_row = b.stride(1) == 1 + lda = a.stride(0) if is_a_row else a.stride(1) + ldb = b.stride(0) if is_b_row else b.stride(1) + ldc = c.stride(0) + lda_pow2_div = _matmul.largest_pow2_divisor(lda) + ldb_pow2_div = _matmul.largest_pow2_divisor(ldb) + ldc_pow2_div = _matmul.largest_pow2_divisor(ldc) + is_tk_div_k = K % 32 == 0 + key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k) + if key not in _matmul._kernels: + defines = { + 'TYPE' : dtype, + 'STRIDE_AM' : 'lda' if is_a_row else '1', + 'STRIDE_AK' : '1' if is_a_row else 'lda', + 'STRIDE_BK' : 'ldb' if is_b_row else '1', + 'STRIDE_BN' : '1' if is_b_row else 'ldb', + 'LDA_POW2_DIV': lda_pow2_div, + 'LDB_POW2_DIV': ldb_pow2_div, + 'LDC_POW2_DIV': ldc_pow2_div, + 'TM' : _matmul.TM, + 'TN' : _matmul.TN, + 'TK' : _matmul.TK, + 'TZ' : _matmul.TZ, + 'IS_TK_DIV_K' : is_tk_div_k + } + _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines) + kernel = _matmul._kernels[key] + # # locks for split-k + if device not in _matmul._locks: + _matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device) + locks = _matmul._locks[device] + # enqueue + alpha = 1. + args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()] + grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1] + kernel(*args, grid=grid) + return c + + @staticmethod + def forward(ctx, a, b): + c = _matmul._call(a,b) + return c + +matmul = _matmul.apply diff --git a/python/triton/ops/softmax.c b/python/triton/ops/softmax.c new file mode 100644 index 000000000..e363901c6 --- /dev/null +++ b/python/triton/ops/softmax.c @@ -0,0 +1,8 @@ +__global__ void forward(TYPE* X, TYPE* Y) { + int pid = get_program_id(0); + int off[BLOCK] = pid * BLOCK + 0 ... BLOCK; + float x[BLOCK] = *(X + off); + float shifted[BLOCK] = exp(x - x[max]); + float sum = shifted[+]; + *(Y + off) = shifted / sum; +} \ No newline at end of file diff --git a/python/triton/ops/softmax.py b/python/triton/ops/softmax.py new file mode 100644 index 000000000..516ea8cc8 --- /dev/null +++ b/python/triton/ops/softmax.py @@ -0,0 +1,27 @@ +import torch +import triton +import os + +kernels = dict() +def get_kernel(block, dtype, device): + key = (block, dtype, device) + if key not in kernels: + src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c')) + defines = {'BLOCK': block, 'TYPE': dtype} + kernels[key] = triton.kernel(src, device = device, defines = defines) + return kernels[key] + + +class _softmax(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + y = torch.empty_like(x) + M, N = x.shape + kernel = get_kernel(N, x.dtype, x.device) + kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ]) + return y + +softmax = _softmax.apply + + diff --git a/python/tutorials/01-vector-add.ipynb b/python/tutorials/01-vector-add.ipynb new file mode 100644 index 000000000..5532d9fc5 --- /dev/null +++ b/python/tutorials/01-vector-add.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "induced-zoning", + "metadata": {}, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "median-malaysia", + "metadata": {}, + "source": [ + "In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n", + "* The basic syntax of the Triton programming language\n", + "* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n", + "* The best practices for validating and benchmarking custom ops against native reference implementations" + ] + }, + { + "cell_type": "markdown", + "id": "identical-conditions", + "metadata": {}, + "source": [ + "# Writing the Compute Kernel" + ] + }, + { + "cell_type": "markdown", + "id": "collectible-belle", + "metadata": {}, + "source": [ + "Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n", + "\n", + "\n", + "```c\n", + "__global__ void add(float* z, float* x, float* y, int N){\n", + " // The `get_program_id(i)` returns the i-th coordinate\n", + " // of the program in the overaching SPMD context\n", + " // (a.k.a launch grid). This is what allows us to process\n", + " // different chunks of data in parallel.\n", + " // For those similar with CUDA, `get_program_id({0,1,2})`\n", + " // is similar to blockIdx.{x,y,z}\n", + " int pid = get_program_id(0);\n", + " // In Triton, arrays are first-class citizen. In other words,\n", + " // they are primitives data-types and are -- contrary to C and\n", + " // CUDA -- not implemented as pointers to contiguous chunks of\n", + " // memory.\n", + " // In the few lines below, we create an array of `BLOCK` pointers\n", + " // whose memory values are, e.g.:\n", + " // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n", + " // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n", + " int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n", + " float* pz [BLOCK] = z + offset;\n", + " float* px [BLOCK] = x + offset;\n", + " float* py [BLOCK] = y + offset;\n", + " // Simple element-wise control-flow for load/store operations can\n", + " // be achieved using the the ternary operator `cond ? val_true : val_false`\n", + " // or the conditional dereferencing operator `*?(cond)ptr\n", + " // Here, we make sure that we do not access memory out-of-bounds when we\n", + " // write-back `z`\n", + " bool check[BLOCK] = offset < N;\n", + " *?(check)pz = *?(check)px + *?(check)py;\n", + "}\n", + "```\n", + "\n", + "The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)." + ] + }, + { + "cell_type": "markdown", + "id": "forbidden-wednesday", + "metadata": {}, + "source": [ + "# Writing the Torch bindings" + ] + }, + { + "cell_type": "markdown", + "id": "numerical-agency", + "metadata": {}, + "source": [ + "The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n", + "\n", + "To create a `triton.kernel`, you only need three things:\n", + "* `source: string`: the source-code of the kernel you want to create\n", + "* `device: torch.device`: the device you want to compile this code for\n", + "* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n", + "\n", + "Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "sporting-keyboard", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import triton\n", + "\n", + "# source-code for Triton compute kernel\n", + "# here we just copy-paste the above code without the extensive comments.\n", + "# you may prefer to store it in a .c file and load it from there instead.\n", + "_src = \"\"\"\n", + "__global__ void add(float* z, float* x, float* y, int N){\n", + " // program id\n", + " int pid = get_program_id(0);\n", + " // create arrays of pointers\n", + " int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n", + " float* pz[BLOCK] = z + offset;\n", + " float* px[BLOCK] = x + offset;\n", + " float* py[BLOCK] = y + offset;\n", + " // bounds checking\n", + " bool check[BLOCK] = offset < N;\n", + " // write-back\n", + " *?(check)pz = *?(check)px + *?(check)py;\n", + "}\n", + " \"\"\"\n", + "# This function returns a callable `triton.kernel` object\n", + "# created from the above source code.\n", + "# For portability, we maintain a cache of kernels for different `torch.device`\n", + "# We compile the kernel with -DBLOCK=1024\n", + "_kernels = dict()\n", + "def make_add_kernel(device):\n", + " if device not in _kernels:\n", + " defines = {'BLOCK': 1024}\n", + " _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n", + " return _kernels[device]\n", + "\n", + "# This is a standard torch custom autograd Function\n", + "# The only difference is that we can now use the above kernel\n", + "# in the `forward` and `backward` functions.`\n", + "class _add(torch.autograd.Function):\n", + " \n", + " @staticmethod\n", + " def forward(ctx, x, y):\n", + " # constraints of the op\n", + " assert x.dtype == torch.float32\n", + " # *allocate output*\n", + " z = torch.empty_like(x)\n", + " # *create launch grid*:\n", + " # this is a function which takes compilation parameters `opt`\n", + " # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n", + " # triton.cdiv is a shortcut for ceil division:\n", + " # triton.cdiv(a, b) = (a + b - 1) // b\n", + " N = z.shape[0]\n", + " grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n", + " # *launch kernel*:\n", + " # pointer to the data of torch tensors can be retrieved with\n", + " # the `.data_ptr()` method\n", + " kernel = make_add_kernel(z.device)\n", + " kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n", + " return z\n", + "# Just like we standard PyTorch ops\n", + "# We use the `.apply` method to create a \n", + "# callable object for our function\n", + "add = _add.apply" + ] + }, + { + "cell_type": "markdown", + "id": "separated-polyester", + "metadata": {}, + "source": [ + "At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!" + ] + }, + { + "cell_type": "markdown", + "id": "exclusive-salvation", + "metadata": {}, + "source": [ + "# Writing a Unit Test" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "supported-ribbon", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n", + "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n", + "The maximum difference between torch and triton is 0.0\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "x = torch.rand(98432, device='cuda')\n", + "y = torch.rand(98432, device='cuda')\n", + "za = x + y\n", + "zb = add(x, y)\n", + "print(za)\n", + "print(zb)\n", + "print(f'The maximum difference between torch and triton is '\n", + " f'{torch.max(torch.abs(za - zb))}')" + ] + }, + { + "cell_type": "markdown", + "id": "otherwise-canadian", + "metadata": {}, + "source": [ + "Seems to work!" + ] + }, + { + "cell_type": "markdown", + "id": "polished-australia", + "metadata": {}, + "source": [ + "# Writing a Benchmark" + ] + }, + { + "cell_type": "markdown", + "id": "historic-glass", + "metadata": {}, + "source": [ + "The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "strange-luxembourg", + "metadata": {}, + "outputs": [], + "source": [ + "# We now want to benchmark the performance of `add`\n", + "# Against that of PyTorch for increasing vector sizes\n", + "def do_bench(fn, warmup = 10, rep = 50):\n", + " start_event = torch.cuda.Event(enable_timing=True)\n", + " end_event = torch.cuda.Event(enable_timing=True)\n", + " ret = fn()\n", + " for i in range(warmup):\n", + " fn()\n", + " torch.cuda.synchronize()\n", + " start_event.record()\n", + " for i in range(rep):\n", + " fn()\n", + " end_event.record()\n", + " torch.cuda.synchronize()\n", + " time_ms = start_event.elapsed_time(end_event) / rep\n", + " return time_ms" + ] + }, + { + "cell_type": "markdown", + "id": "hairy-claim", + "metadata": {}, + "source": [ + "We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "pleasant-valley", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "131072 0.020 0.003\n", + "262144 0.019 0.004\n", + "524288 0.016 0.016\n", + "1048576 0.033 0.033\n", + "2097152 0.071 0.070\n", + "4194304 0.142 0.144\n", + "8388608 0.287 0.286\n", + "16777216 0.572 0.568\n", + "33554432 1.139 1.110\n" + ] + } + ], + "source": [ + "for N in [2**i for i in range(17, 26, 1)]:\n", + " x = torch.rand(N, device='cuda')\n", + " y = torch.rand(N, device='cuda')\n", + " triton_ms = do_bench(lambda: add(x, y))\n", + " torch_ms = do_bench(lambda: x + y)\n", + " # print the performance of triton and torch as well as the achieved bandwidth\n", + " print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')" + ] + }, + { + "cell_type": "markdown", + "id": "juvenile-supplement", + "metadata": {}, + "source": [ + "Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "agreed-backing", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt deleted file mode 100644 index 8c80ee070..000000000 --- a/tests/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -include_directories("${CMAKE_CURRENT_SOURCE_DIR}/common") -add_subdirectory(bench) -add_subdirectory(unit) diff --git a/tests/bench/conv.cc b/tests/bench/conv.cc deleted file mode 100644 index 67aa090c4..000000000 --- a/tests/bench/conv.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "conv.h" - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to benchmark - typedef std::tuple config_t; - std::vector configs = { -// {1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1}, -// {1, 56, 56, 128, 128, 3, 3, 1, 1, 1, 1}, -// {1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1}, -// {1, 56, 56, 384, 384, 3, 3, 1, 1, 1, 1}, -// {1, 56, 56, 512, 512, 3, 3, 1, 1, 1, 1}, -// {1, 56, 56, 768, 768, 3, 3, 1, 1, 1, 1}, -// {1, 56, 56, 1024, 1024, 3, 3, 1, 1, 1, 1}, - -// {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1}, -// {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1}, - {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1}, -// {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1}, -// {1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1}, -// {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1} - - - - - }; - int Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w; - for(const auto& c: configs){ - std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c; - std::cout << "// " << c ; - for(auto perf: bench_conv(context, stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w)) - std::cout << ", " << perf << std::flush; - std::cout << std::endl; - } -} diff --git a/tests/bench/copy.cc b/tests/bench/copy.cc deleted file mode 100644 index 09869f5cf..000000000 --- a/tests/bench/copy.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include -#include -#include "copy.h" -#include "triton/driver/backend.h" - - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to benchmark - typedef std::tuple, std::vector, std::vector> config_t; - std::vector configs = { - {{4096*4096}, {0}, {0}}, - {{4096, 4096}, {0, 1}, {0, 1}}, - {{4096, 4096}, {0, 1}, {1, 0}}, - {{4096, 4096}, {1, 0}, {0, 1}}, - {{4096, 4096}, {1, 0}, {1, 0}}, - {{256, 256, 256}, {0, 1, 2}, {0, 1, 2}}, - {{256, 256, 256}, {0, 1, 2}, {0, 2, 1}}, - {{256, 256, 256}, {1, 0, 2}, {1, 2, 0}}, - {{256, 256, 256}, {1, 2, 0}, {1, 0, 2}} -// {{256, 256, 256}, {2, 0, 1}, {0, 1, 2}}, -// {{256, 256, 256}, {2, 1, 0}, {0, 2, 1}} - }; - // does the work - std::vector shape; - std::vector ord_x, ord_y; - for(const auto& c: configs){ - std::tie(shape, ord_x, ord_y) = c; - std::cout << "// " << c << std::flush; - for(auto perf: bench_copy_nd(context, stream, HALF, shape, ord_x, ord_y)) - std::cout << ", " << perf << std::flush; - std::cout << std::endl; - } -} diff --git a/tests/bench/copy1d.cc b/tests/bench/copy1d.cc deleted file mode 100644 index 51afbacd6..000000000 --- a/tests/bench/copy1d.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include -#include -#include -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "triton/runtime/function.h" -#include "src/copy.h" -#include "util.h" -#include "cuda/cublas.h" - - -std::vector do_bench(drv::stream* stream, int32_t N){ - typedef float NumericT; - std::string ty = "float"; - size_t dt_nbytes = sizeof(NumericT); - drv::context* context = stream->context(); - // create inputs - auto dx = std::unique_ptr(drv::buffer::create(context, N*dt_nbytes)); - auto dy = std::unique_ptr(drv::buffer::create(context, N*dt_nbytes)); - // create options - rt::function::options_space_t opt; - opt.defines.push_back({"TYPE", {ty}}); - opt.defines.push_back({"TN", {"128"}}); - opt.num_warps = {1, 2, 4, 8}; - // create function - rt::function function(src::copy1d, opt); - // benchmark available libraries - std::vector result; - auto gbps = [&](double ns) { return 2*N*dt_nbytes / (ns * 1e-9) * 1e-9; }; - // triton - double triton_ns = triton::tools::bench([&]() { function({&*dx, &*dy, N}, grid1d(N), stream);}, stream); - result.push_back(gbps(triton_ns)); - // done - return result; -} - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context); - // shapes to benchmark - typedef std::tuple config_t; - std::vector configs = { 1024*1024*32 }; - int N; - for(const auto& c: configs){ - std::tie(N) = c; - std::cout << "// " << c << std::flush; - for(auto perf: do_bench(stream, N)) - std::cout << ", " << perf << std::flush; - std::cout << std::endl; - } -} diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc deleted file mode 100644 index 483fde4df..000000000 --- a/tests/bench/dot.cc +++ /dev/null @@ -1,72 +0,0 @@ -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "dot.h" - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to benchmark - typedef std::tuple, bool, bool, int, int, int> config_t; - std::vector configs; - for(auto ord: std::vector>{{1, 0}}) - for(auto x: std::vector>{{false, true}, {false, false}, {true, false}, {true, true}}){ - std::vector tmp = { -// config_t{ord, x[0], x[1], 128, 128, 128}, -// config_t{ord, x[0], x[1], 256, 256, 256}, -// config_t{ord, x[0], x[1], 384, 384, 384}, -// config_t{ord, x[0], x[1], 512, 512, 512}, -// config_t{ord, x[0], x[1], 768, 768, 768}, - // config_t{ord, x[0], x[1], 1024, 1024, 1024}, -// config_t{ord, x[0], x[1], 1280, 1280, 1280}, -// config_t{ord, x[0], x[1], 1536, 1536, 1536}, -// config_t{ord, x[0], x[1], 2048, 2048, 2048}, - config_t{ord, x[0], x[1], 8192, 8192, 8192}, - -// config_t{ord, x[0], x[1], 256, 16, 256}, -// config_t{ord, x[0], x[1], 512, 16, 512}, -// config_t{ord, x[0], x[1], 768, 16, 768}, -// config_t{ord, x[0], x[1], 1024, 16, 1024}, -// config_t{ord, x[0], x[1], 1280, 16, 1280}, -// config_t{ord, x[0], x[1], 1536, 16, 1536}, -// config_t{ord, x[0], x[1], 2048, 16, 2048}, -// config_t{ord, x[0], x[1], 3072, 16, 3072}, -// config_t{ord, x[0], x[1], 4096, 16, 4096}, -// config_t{ord, x[0], x[1], 5120, 16, 5120}, -// config_t{ord, x[0], x[1], 6144, 16, 6144}, -// config_t{ord, x[0], x[1], 7168, 16, 7168}, - -// config_t{ord, x[0], x[1], 64, 64, 4096}, -// config_t{ord, x[0], x[1], 64, 64, 8192}, -// config_t{ord, x[0], x[1], 64, 64, 16384}, -// config_t{ord, x[0], x[1], 64, 64, 32768}, -// config_t{ord, x[0], x[1], 64, 64, 65536}, -// config_t{ord, x[0], x[1], 64, 64, 131072} - -// config_t{ord, x[0], x[1], 127008, 768, 576}, -// config_t{ord, x[0], x[1], 8192, 8192, 8192} -// config_t{ord, x[0], x[1], 16, 2048, 2048}, -// config_t{ord, x[0], x[1], 32, 2048, 2048}, -// config_t{ord, x[0], x[1], 64, 2048, 2048}, -// config_t{ord, x[0], x[1], 128, 2048, 2048}, -// config_t{ord, x[0], x[1], 7000, 2048, 2048}, -// config_t{ord, x[0], x[1], 16, 4096, 4096}, -// config_t{ord, x[0], x[1], 32, 4096, 4096}, -// config_t{ord, x[0], x[1], 64, 4096, 4096}, -// config_t{ord, x[0], x[1], 128, 4096, 4096}, -// config_t{ord, x[0], x[1], 7000, 4096, 4096} - }; - configs.insert(configs.end(), tmp.begin(), tmp.end()); - } - // does the work - std::vector ord; - bool AT, BT; - int32_t M, N, K; - for(const auto& c: configs){ - std::tie(ord, AT, BT, M, N, K) = c; - std::cout << "// " << c ; - for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord)) - std::cout << ", " << perf << std::flush; - std::cout << std::endl; - } -} diff --git a/tests/common/conv.h b/tests/common/conv.h deleted file mode 100644 index d2abdd718..000000000 --- a/tests/common/conv.h +++ /dev/null @@ -1,140 +0,0 @@ -#include -#include -#include -#include -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "triton/runtime/function.h" -#include "src/conv.h" -#include "cuda/cublas.h" -#include "util.h" - -enum run_mode_t { - BENCH, - TEST -}; - -enum dtype_t { - FLOAT, - HALF, - DOUBLE -}; - -template -struct to_string; - -template<> struct to_string{ - static constexpr const char* value = "half"; -}; - -template<> struct to_string{ - static constexpr const char* value = "float"; -}; - -template<> struct to_string{ - static constexpr const char* value = "double"; -}; - -template -void triton_conv(drv::context* context, drv::stream* stream, - int Z, int CI, int H, int W, int CO, int R, int S, - int pad_h, int pad_w, int stride_h, int stride_w, - run_mode_t mode, std::vector& bench, bool &test){ - std::string ty = to_string::value; - size_t dt_nbytes = sizeof(T); - drv::device* device = context->device(); - - int P = (H + 2*pad_h - R)/stride_h + 1; - int Q = (W + 2*pad_w - S)/stride_w + 1; - - // inputs - auto dc = std::shared_ptr(drv::buffer::create(context, Z*CO*P*Q*dt_nbytes)); - auto da = std::shared_ptr(drv::buffer::create(context, Z*CI*H*W*dt_nbytes)); - auto db = std::shared_ptr(drv::buffer::create(context, CI*R*S*CO*dt_nbytes)); - auto ddelta = std::shared_ptr(drv::buffer::create(context, CI*R*S*4)); - auto dlocks = std::shared_ptr(drv::buffer::create(context, 1024*1024*2*4)); - ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size()); - - std::vector hdelta(CI*R*S); - int TK = 16; - for(int i = 0; i < hdelta.size(); i++){ - int s = i % S; - int cr = i / S; - int r = cr % R; - int c = cr / R; - int nexti = i + TK; - int nexts = nexti % S; - int nextcr = nexti / S; - int nextr = nextcr % R; - int nextc = nextcr / R; - hdelta[i] = (nextc - c)*W*H + (nextr - r)*W + (nexts - s); - } - stream->write(&*ddelta, true, 0, hdelta); - - // macros - rt::options_space_t opt; - opt.defines.push_back({"TYPE", {ty}}); - opt.defines.push_back({"TM", {"64", "128"}}); - opt.defines.push_back({"TN", {"64", "128"}}); - opt.defines.push_back({"TK", {std::to_string(TK)}}); - opt.defines.push_back({"TZ", {"1"}}); - opt.defines.push_back({"RR", {std::to_string(R)}}); - opt.defines.push_back({"SS", {std::to_string(S)}}); - opt.defines.push_back({"PP", {std::to_string(P)}}); - opt.defines.push_back({"QQ", {std::to_string(Q)}}); - opt.defines.push_back({"HH", {std::to_string(H)}}); - opt.defines.push_back({"WW", {std::to_string(W)}}); - opt.num_warps = {4}; - // arguments - std::stringstream oss; - rt::add_arg(oss, *da->cu()); - rt::add_arg(oss, *db->cu()); - rt::add_arg(oss, *dc->cu()); - rt::add_arg(oss, (float)1); - rt::add_arg(oss, Z*P*Q); - rt::add_arg(oss, CO); - rt::add_arg(oss, CI*R*S); - rt::add_arg(oss, pad_h); - rt::add_arg(oss, pad_w); - rt::add_arg(oss, stride_h); - rt::add_arg(oss, stride_w); - rt::add_arg(oss, *ddelta->cu()); - rt::add_arg(oss, W*H*CI); - rt::add_arg(oss, W*H); - rt::add_arg(oss, W); - rt::add_arg(oss, 1); - rt::add_arg(oss, CO*S*R); - rt::add_arg(oss, CO*S); - rt::add_arg(oss, CO); - rt::add_arg(oss, 1); - rt::add_arg(oss, Q*P*CO); - rt::add_arg(oss, Q*P); - rt::add_arg(oss, Q); - rt::add_arg(oss, 1); - // kernels - rt::function function(src::conv, opt); - auto grid = [Z,P,Q,CO](const rt::options_t& x) { - return rt::grid_t{ceil(Z*P*Q, x.D("TM")), - ceil(CO , x.D("TN")), - (size_t)x.D("TZ")}; - }; - auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); - bench.push_back(tflops(triton_ns)); -} - -std::vector bench_conv(drv::context* context, drv::stream* stream, dtype_t dtype, - int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S, - int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) { - std::vector bench; - bool test; - switch(dtype){ - case HALF: triton_conv(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break; - case FLOAT: triton_conv(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break; - case DOUBLE: triton_conv(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break; - default: break; - } - return bench; -} diff --git a/tests/common/copy.h b/tests/common/copy.h deleted file mode 100644 index 4925c4c84..000000000 --- a/tests/common/copy.h +++ /dev/null @@ -1,186 +0,0 @@ -#include "src/copy.h" -#include "triton/driver/stream.h" -#include "triton/runtime/function.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "util.h" - -int32_t off(const std::vector& idx, const std::vector& strides) { - int32_t res = 0; - for(size_t d = 0; d < idx.size(); d++) - res += idx[d] * strides[d]; - return res; -} - -enum run_mode_t { - BENCH, - TEST -}; - -enum dtype_t { - FLOAT, - HALF, - DOUBLE -}; - - -template -void cc_copy_nd(const std::vector& x, std::vector& y, - const std::vector& shape, - const std::vector& x_order, const std::vector& y_order) { - size_t rank = shape.size(); - // strides for x - std::vector x_strides(shape.size()); - for(size_t d = 0; d < rank; d++) - x_strides[x_order[d]] = (d == 0) ? 1 : (x_strides[x_order[d-1]] * shape[x_order[d-1]]); - // strides for y - std::vector y_strides(shape.size()); - for(size_t d = 0; d < rank; d++) - y_strides[y_order[d]] = (d == 0) ? 1 : (y_strides[y_order[d-1]] * shape[y_order[d-1]]); - // copy 1d - if(rank == 1) - for(int32_t i = 0; i < shape[0]; i++) - y[off({i}, y_strides)] = x[off({i}, x_strides)]; - // copy 2d - if(rank == 2) - for(int32_t i = 0; i < shape[0]; i++) - for(int32_t j = 0; j < shape[1]; j++) - y[off({i, j}, y_strides)] = x[off({i, j}, x_strides)]; - // copy 3d - if(rank == 3) - for(int32_t i = 0; i < shape[0]; i++) - for(int32_t j = 0; j < shape[1]; j++) - for(int32_t k = 0; k < shape[2]; k++) - y[off({i, j, k}, y_strides)] = x[off({i, j, k}, x_strides)]; -} - -template -struct to_string; - -template<> struct to_string{ - static constexpr const char* value = "half"; -}; - -template<> struct to_string{ - static constexpr const char* value = "float"; -}; - -template<> struct to_string{ - static constexpr const char* value = "double"; -}; - -template -void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vector& shape, - const std::vector& x_order, const std::vector& y_order, - std::vector> TS, - run_mode_t mode, std::vector& bench, bool &test) { - std::string ty = to_string::value; - size_t dtsize = sizeof(T); - drv::device* device = context->device(); - - // rank - size_t rank = shape.size(); - // size - size_t size = 1; - for(int32_t d: shape) - size *= d; - std::vector shapename = {"S0", "S1", "S2"}; - // strides for x - std::vector x_strides = {"1"}; - for(size_t d = 0; d < rank - 1; d++) - x_strides.push_back(x_strides[d] + " * " + shapename[x_order[d]]); - // strides for y - std::vector y_strides = {"1"}; - for(size_t d = 0; d < rank - 1; d++) - y_strides.push_back(y_strides[d] + " * " + shapename[y_order[d]]); - - // create inputs - auto dx = std::unique_ptr(drv::buffer::create(context, size*dtsize)); - auto dy = std::unique_ptr(drv::buffer::create(context, size*dtsize)); - // create options - rt::options_space_t opt; - - - // macros - opt.defines.push_back({"TYPE", {ty}}); - for(size_t d = 0; d < rank; d++) - opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}}); - for(size_t d = 0; d < rank; d++) - opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}}); - if(TS.empty()) - TS = tile_nd(rank); - for(size_t d = 0; d < rank; d++) - opt.defines.push_back({"TS" + std::to_string(d), TS[d]}); - opt.num_warps = {4}; - - // kernel - rt::function function(src::copy_nd[rank - 1], opt); - - std::stringstream oss; - rt::add_arg(oss, *dx->cu()); - rt::add_arg(oss, *dy->cu()); - rt::add_arg(oss, (uint32_t)shape[0]); - if(shape.size() > 1) rt::add_arg(oss, (uint32_t)shape[1]); - if(shape.size() > 2) rt::add_arg(oss, (uint32_t)shape[2]); - - std::vector ts = {"TS0", "TS1", "TS2"}; - auto grid = grid_nd(shape, ts); - - // metrics - if(mode == BENCH){ - auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; }; - double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); - bench.push_back(gbps(triton_ns)); - } - - // test triton - if(mode == TEST){ - std::vector hx(size); - std::vector hy(size); - std::vector ry(size); - for(size_t i = 0; i < hx.size(); i++) - hx[i] = static_cast((float)rand()/RAND_MAX); - stream->write(&*dx, true, 0, hx); - function((void**)oss.str().data(), oss.str().size(), grid, stream, device); - stream->synchronize(); - stream->read(&*dy, true, 0, hy); - cc_copy_nd(hx, ry, shape, x_order, y_order); - test = testing::diff(hy, ry); - } -} - -std::vector bench_copy_nd(drv::context* context, drv::stream* stream, dtype_t dtype, const std::vector& shape, - const std::vector& x_order, const std::vector& y_order) { - std::vector bench; - bool test; - switch(dtype){ - case HALF: - triton_copy_nd(context, stream, shape, x_order, y_order, {}, BENCH, bench, test); - break; - case FLOAT: - triton_copy_nd(context, stream, shape, x_order, y_order, {}, BENCH, bench, test); - break; - default: break; - } - return bench; -} - -bool test_copy_nd(drv::context* context, drv::stream* stream, dtype_t dtype, const std::vector& shape, - const std::vector& TS, - const std::vector& x_order, const std::vector& y_order) { - std::vector bench; - bool test; - std::vector> TSS; - for(int32_t d: TS) - TSS.push_back({std::to_string(d)}); - switch(dtype){ - case HALF: - triton_copy_nd(context, stream, shape, x_order, y_order, TSS, TEST, bench, test); - break; - case FLOAT: - triton_copy_nd(context, stream, shape, x_order, y_order, TSS, TEST, bench, test); - break; - default: break; - } - return test; -} diff --git a/tests/common/cuda/cublas.h b/tests/common/cuda/cublas.h deleted file mode 100644 index aee399f57..000000000 --- a/tests/common/cuda/cublas.h +++ /dev/null @@ -1,218 +0,0 @@ -/* Copyright 2019 Philippe Tillet -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, -* subject to the following conditions: -* -* The above copyright notice and this permission notice shall be -* included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -*/ - -#include -#include -#include -#include "forward.h" -#include "triton/driver/buffer.h" -#include "triton/driver/stream.h" -#include "triton/driver/context.h" -#include "triton/driver/error.h" -#include "triton/tools/bench.hpp" - - -class cublas { -private: - template - struct return_type; - - template - struct return_type - { typedef R type; }; - - typedef bool (*f_init_t)(); - - template - static typename return_type::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args) - { - initializer(); - if(cache == nullptr){ - cache = dlsym(lib_h, name); - if(cache == 0) - throw std::runtime_error("dlsym unable to load function"); - } - FunPtrT fptr; - *reinterpret_cast(&fptr) = cache; - typename return_type::type res = (*fptr)(args...); - triton::driver::check(res); - return res; - } - -public: - static bool cublasinit(); - static cublasStatus_t cublasSetMathMode(cublasHandle_t h, cublasMath_t m); - static cublasStatus_t cublasCreate_v2(cublasHandle_t* h); - static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId); - static cublasStatus_t cublasSetStream_v2(cublasHandle_t h, cudaStream_t streamId); - static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, - const void *alpha, const void *A, cudaDataType Atype, int lda, - const void *B, cudaDataType Btype, int ldb, const void *beta, - void *C, cudaDataType Ctype, int ldc, - cudaDataType computeType, cublasGemmAlgo_t algo); - -private: - static void* so_; - static void* cublasGetStream_v2_; - static void* cublasSetStream_v2_; - static void* cublasCreate_v2_; - static void* cublasGemmEx_; - static void* cublasSetMathMode_; -}; - -void* cublas::so_; -void* cublas::cublasGetStream_v2_; -void* cublas::cublasSetStream_v2_; -void* cublas::cublasCreate_v2_; -void* cublas::cublasGemmEx_; -void* cublas::cublasSetMathMode_; - - -bool cublas::cublasinit() { - if(so_==nullptr) - so_ = dlopen("libcublas.so", RTLD_LAZY); - return so_ != nullptr; -} - -cublasStatus_t cublas::cublasGetStream_v2(cublasHandle_t h, cudaStream_t *a) -{ return f_impl(so_, cublasGetStream_v2, cublasGetStream_v2_, "cublasGetStream_v2", h, a); } -cublasStatus_t cublas::cublasSetStream_v2(cublasHandle_t h, cudaStream_t a) -{ return f_impl(so_, cublasSetStream_v2, cublasSetStream_v2_, "cublasSetStream_v2", h, a); } -cublasStatus_t cublas::cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const void *alpha, const void *A, cudaDataType Atype, int lda, - const void *B, cudaDataType Btype, int ldb, const void *beta, - void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo) { - return f_impl(so_, cublasGemmEx, cublasGemmEx_, "cublasGemmEx", handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo); -} -cublasStatus_t cublas::cublasCreate_v2(cublasHandle_t *h) { - return f_impl(so_, cublasCreate_v2, cublasCreate_v2_, "cublasCreate_v2", h); -} -cublasStatus_t cublas::cublasSetMathMode(cublasHandle_t h, cublasMath_t m) { - return f_impl(so_, cublasSetMathMode, cublasSetMathMode_, "cublasSetMathMode", h, m); -} - - - -inline cublasGemmAlgo_t cublasGemmFastest( - triton::driver::stream* stream, - cublasHandle_t handle, cudaDataType cudt, - cublasOperation_t AT, cublasOperation_t BT, - int32_t M, int32_t N, int32_t K, - void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, - void* beta, CUdeviceptr C, int32_t ldc) { - - // initialize list of cublas algorithms - static std::vector algorithms; - if(algorithms.empty()) { - // non-tensor ops - for(int i = -1; i < 24; i++) - algorithms.push_back((cublasGemmAlgo_t)i); - // tensor ops - for(int i = 99; i < 116; i++) - algorithms.push_back((cublasGemmAlgo_t)i); - } - - // cache to avoid re-benchmarking - typedef std::tuple key_t; - static std::map cache; - key_t key(cudt, AT, BT, M, N, K); - // benchmark algorithms if necessary - if(cache.find(key) == cache.end()){ - std::vector times; - for(cublasGemmAlgo_t a: algorithms) { - cublasStatus_t status; - double nanosec = triton::tools::bench([&](){ status = cublas::cublasGemmEx(handle, AT, BT, - M, N, K, - alpha, (const void*)A, cudt, lda, - (const void*)B, cudt, ldb, - beta, (void*)C, cudt, ldc, CUDA_R_32F, - a); }, stream); - if(status != CUBLAS_STATUS_SUCCESS) - nanosec = INFINITY; - } - size_t argmin = std::min_element(times.begin(), times.end()) - times.begin(); - assert(times[argmin] != INFINITY); - cache.insert({key, algorithms[argmin]}); - } - - // return best algorithm - return cache.at(key); -} - - - - -/* Get cuBLAS handle */ -inline cublasHandle_t cublasGetHandle(triton::driver::stream* stream) { - static std::map cache; - CUstream key = *stream->cu(); - - // create handle if necessary - if(cache.find(key) == cache.end()) { - cublasHandle_t handle; - if(cublas::cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS) - throw std::runtime_error("Error: could not create cuBLAS handle"); - cublas::cublasSetStream_v2(handle, key); - cache.insert({key, handle}); - } - - // return handle for the stream - return cache.at(key); -} - - - -/* Simplified API for default GEMM */ -inline void cublasGemm(cublasDataType_t dtype, - triton::driver::stream* stream, - bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - void* alpha, triton::driver::buffer* A, int32_t lda, - triton::driver::buffer* B, int32_t ldb, - void* beta, triton::driver::buffer* C, int32_t ldc, - cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { - // get handle - static cublasHandle_t handle = cublasGetHandle(stream); - // set math mode - if(dtype == CUDA_R_16F) - cublas::cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); - // cuda types - static const std::map cu_op = { - {false, CUBLAS_OP_N}, - {true, CUBLAS_OP_T} - }; - cublasOperation_t opa = cu_op.at(AT); - cublasOperation_t opb = cu_op.at(BT); - // benchmark fastest - if(fastest) - *fastest = cublasGemmFastest(stream, handle, dtype, opa, opb, M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc); - else { - // execute supplied algo - cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K, - alpha, (const void*)*A->cu(), dtype, lda, - (const void*)*B->cu(), dtype, ldb, - beta, (void*)*C->cu(), dtype, ldc, CUDA_R_32F, algo); - } -} diff --git a/tests/common/cuda/forward.h b/tests/common/cuda/forward.h deleted file mode 100644 index bd32adec6..000000000 --- a/tests/common/cuda/forward.h +++ /dev/null @@ -1,105 +0,0 @@ -#ifndef _COMMON_CUDA_FORWARDS_H_ -#define _COMMON_CUDA_FORwARDS_H_ - -struct cublasContext; -typedef struct cublasContext *cublasHandle_t; -struct CUstream_st; -typedef struct CUstream_st *cudaStream_t; - -/* CUBLAS status type returns */ -typedef enum{ - CUBLAS_STATUS_SUCCESS =0, - CUBLAS_STATUS_NOT_INITIALIZED =1, - CUBLAS_STATUS_ALLOC_FAILED =3, - CUBLAS_STATUS_INVALID_VALUE =7, - CUBLAS_STATUS_ARCH_MISMATCH =8, - CUBLAS_STATUS_MAPPING_ERROR =11, - CUBLAS_STATUS_EXECUTION_FAILED=13, - CUBLAS_STATUS_INTERNAL_ERROR =14, - CUBLAS_STATUS_NOT_SUPPORTED =15, - CUBLAS_STATUS_LICENSE_ERROR =16 -} cublasStatus_t; - -/*For different GEMM algorithm */ -typedef enum { - CUBLAS_GEMM_DFALT = -1, - CUBLAS_GEMM_DEFAULT = -1, - CUBLAS_GEMM_ALGO0 = 0, // maxwell_sgemm_32x128_nt - CUBLAS_GEMM_ALGO1 = 1, // maxwell_sgemm_64x64_nt - CUBLAS_GEMM_ALGO2 = 2, // maxwell_sgemm_128x32_nt - CUBLAS_GEMM_ALGO3 = 3, // maxwell_sgemm_128x64_nt - CUBLAS_GEMM_ALGO4 = 4, // maxwell_sgemm_128x128_nt - CUBLAS_GEMM_ALGO5 = 5, - CUBLAS_GEMM_ALGO6 = 6, - CUBLAS_GEMM_ALGO7 = 7, - CUBLAS_GEMM_ALGO8 = 8, - CUBLAS_GEMM_ALGO9 = 9, - CUBLAS_GEMM_ALGO10 = 10, - CUBLAS_GEMM_ALGO11 = 11, - CUBLAS_GEMM_ALGO12 = 12, - CUBLAS_GEMM_ALGO13 = 13, - CUBLAS_GEMM_ALGO14 = 14, - CUBLAS_GEMM_ALGO15 = 15, - CUBLAS_GEMM_ALGO16 = 16, - CUBLAS_GEMM_ALGO17 = 17, - CUBLAS_GEMM_ALGO18 = 18, //sliced 32x32 - CUBLAS_GEMM_ALGO19 = 19, //sliced 64x32 - CUBLAS_GEMM_ALGO20 = 20, //sliced 128x32 - CUBLAS_GEMM_ALGO21 = 21, //sliced 32x32 -splitK - CUBLAS_GEMM_ALGO22 = 22, //sliced 64x32 -splitK - CUBLAS_GEMM_ALGO23 = 23, //sliced 128x32 -splitK - CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99, - CUBLAS_GEMM_DFALT_TENSOR_OP = 99, - CUBLAS_GEMM_ALGO0_TENSOR_OP = 100, - CUBLAS_GEMM_ALGO1_TENSOR_OP = 101, - CUBLAS_GEMM_ALGO2_TENSOR_OP = 102, - CUBLAS_GEMM_ALGO3_TENSOR_OP = 103, - CUBLAS_GEMM_ALGO4_TENSOR_OP = 104, - CUBLAS_GEMM_ALGO5_TENSOR_OP = 105, - CUBLAS_GEMM_ALGO6_TENSOR_OP = 106, - CUBLAS_GEMM_ALGO7_TENSOR_OP = 107, - CUBLAS_GEMM_ALGO8_TENSOR_OP = 108, - CUBLAS_GEMM_ALGO9_TENSOR_OP = 109, - CUBLAS_GEMM_ALGO10_TENSOR_OP = 110, - CUBLAS_GEMM_ALGO11_TENSOR_OP = 111, - CUBLAS_GEMM_ALGO12_TENSOR_OP = 112, - CUBLAS_GEMM_ALGO13_TENSOR_OP = 113, - CUBLAS_GEMM_ALGO14_TENSOR_OP = 114, - CUBLAS_GEMM_ALGO15_TENSOR_OP = 115 -} cublasGemmAlgo_t; - -typedef enum cudaDataType_t -{ - CUDA_R_16F= 2, /* real as a half */ - CUDA_C_16F= 6, /* complex as a pair of half numbers */ - CUDA_R_32F= 0, /* real as a float */ - CUDA_C_32F= 4, /* complex as a pair of float numbers */ - CUDA_R_64F= 1, /* real as a double */ - CUDA_C_64F= 5, /* complex as a pair of double numbers */ - CUDA_R_8I = 3, /* real as a signed char */ - CUDA_C_8I = 7, /* complex as a pair of signed char numbers */ - CUDA_R_8U = 8, /* real as a unsigned char */ - CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */ - CUDA_R_32I= 10, /* real as a signed int */ - CUDA_C_32I= 11, /* complex as a pair of signed int numbers */ - CUDA_R_32U= 12, /* real as a unsigned int */ - CUDA_C_32U= 13 /* complex as a pair of unsigned int numbers */ -} cudaDataType; - -typedef cudaDataType cublasDataType_t; - -typedef enum { - CUBLAS_OP_N=0, - CUBLAS_OP_T=1, - CUBLAS_OP_C=2, - CUBLAS_OP_HERMITAN=2, /* synonym if CUBLAS_OP_C */ - CUBLAS_OP_CONJG=3 /* conjugate */ -} cublasOperation_t; - -/*Enum for default math mode/tensor operation*/ -typedef enum { - CUBLAS_DEFAULT_MATH = 0, - CUBLAS_TENSOR_OP_MATH = 1 -} cublasMath_t; - -#endif diff --git a/tests/common/dot.h b/tests/common/dot.h deleted file mode 100644 index 5b8d4c15d..000000000 --- a/tests/common/dot.h +++ /dev/null @@ -1,215 +0,0 @@ -#include -#include -#include -#include -#include -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "triton/runtime/function.h" -#include "src/dot.h" -#include "cuda/cublas.h" -#include "util.h" - -template -static void cc_dot(std::vector &c, const std::vector &a, const std::vector &b, - size_t M, size_t N, size_t K){ - for(size_t m = 0; m < M; m++) - for(size_t n = 0; n < N; n++){ - float acc = 0; - for(size_t k = 0; k < K; k++) - acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]); - c[m*N + n] = static_cast(acc); - } -} - -template -void cc_dot(bool AT_, bool BT_, size_t M, size_t N, size_t K, - std::vector &c, const std::vector &a, const std::vector &b) { - if(AT_ && BT_) - cc_dot(c, a, b, M, N, K); - else if(AT_ && !BT_) - cc_dot(c, a, b, M, N, K); - else if(!AT_ && BT_) - cc_dot(c, a, b, M, N, K); - else - cc_dot(c, a, b, M, N, K); -} - -enum run_mode_t { - BENCH, - TEST -}; - -enum dtype_t { - FLOAT, - HALF, - DOUBLE -}; - -template -struct to_string; - -template<> struct to_string{ - static constexpr const char* value = "half"; -}; - -template<> struct to_string{ - static constexpr const char* value = "float"; -}; - -template<> struct to_string{ - static constexpr const char* value = "double"; -}; - -template -void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - int32_t TM, int32_t TN, int32_t TK, int32_t nwarp, - const std::vector& a_order, const std::vector& b_order, - run_mode_t mode, std::vector& bench, bool &test){ - std::string ty = to_string::value; - size_t dt_nbytes = sizeof(T); - drv::device* device = context->device(); - int32_t lda = (AT ^ a_order[0]==1) ? K : M; - int32_t ldb = (BT ^ b_order[0]==1) ? N : K; - int32_t ldc = N; - std::vector sa = { "1", "lda" }; - std::vector sb = { "1", "ldb" }; - - // inputs - auto dc = std::shared_ptr(drv::buffer::create(context, M*N*dt_nbytes)); - auto da = std::shared_ptr(drv::buffer::create(context, M*K*dt_nbytes)); - auto db = std::shared_ptr(drv::buffer::create(context, K*N*dt_nbytes)); - auto dlocks = std::shared_ptr(drv::buffer::create(context, 1024*1024*2*4)); -// ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size()); - - // macros - rt::options_space_t opts; - // A access patterns - opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }}); - opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }}); - // B access patterns - opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }}); - opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }}); - // data-type - opts.defines.push_back({"TYPE", {ty}}); - // tile sizes - if(mode == TEST) { - opts.defines.push_back({"TM", {std::to_string(TM)}}); - opts.defines.push_back({"TN", {std::to_string(TN)}}); - opts.defines.push_back({"TK", {std::to_string(TK)}}); - opts.defines.push_back({"TZ", {"1"}}); - opts.num_warps = {nwarp}; - } - if(mode == BENCH) { - opts.defines.push_back({"TM", {"128"}}); - opts.defines.push_back({"TN", {"128"}}); - opts.defines.push_back({"TK", {"32"}}); - opts.defines.push_back({"TZ", {"1"}}); - opts.num_warps = {4}; - } - - // arguments - std::stringstream oss; - rt::add_arg(oss, *da->cu()); - rt::add_arg(oss, *db->cu()); - rt::add_arg(oss, *dc->cu()); - rt::add_arg(oss, (float)1); - rt::add_arg(oss, M); - rt::add_arg(oss, N); - rt::add_arg(oss, K); - rt::add_arg(oss, lda); - rt::add_arg(oss, ldb); - rt::add_arg(oss, ldc); - rt::add_arg(oss, *dlocks->cu()); - // kernel - rt::function function(src::dot, opts); - // grid - auto grid = [M, N](const rt::options_t& x) { - return rt::grid_t{ceil(M, x.D("TM"))* - ceil(N, x.D("TN")), - (size_t)x.D("TZ")}; - }; - - // metrics - if(mode == BENCH){ - auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); - bench.push_back(tflops(triton_ns)); - - // cublas - if(cublas::cublasinit()){ - T alpha(static_cast(1)); - T beta(static_cast(0)); - cublasGemmAlgo_t fastest; -// cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); - double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, !AT, !BT, M, N, K, - &alpha, &*da, lda, &*db, ldb, &beta, &*dc, - ldc); }, stream); - bench.push_back(tflops(cublas_ms)); - } - } - -// rt::options_t opt; -// for(auto &x: opts.defines) -// opt.defines[x.first] = x.second[0]; -// opt.num_warps = 1; -// std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl; - - // test triton - if(mode == TEST){ - srand(0); - // initialize buffers - std::vector hc(M*N); - std::vector ha(M*K); - std::vector hb(K*N); - for(size_t i = 0; i < ha.size(); i++) - ha[i] = (float)rand()/RAND_MAX; - for(size_t i = 0; i < hb.size(); i++) - hb[i] = (float)rand()/RAND_MAX; - // copy buffer - stream->write(&*da, true, 0, ha); - stream->write(&*db, true, 0, hb); - // run kernel - function((void**)oss.str().data(), oss.str().size(), grid, stream, device); - // write back - stream->synchronize(); - // compare with CPU - stream->read(&*dc, true, 0, hc); - std::vector rc(hc.size()); - cc_dot(AT, BT, M, N, K, rc, ha, hb); - test = testing::diff(hc, rc); - } -} - -std::vector bench_dot(drv::context* context, drv::stream* stream, - dtype_t dtype, bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - const std::vector& a_order, const std::vector& b_order) { - std::vector bench; - bool test; - switch(dtype){ - case HALF: triton_dot(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break; - case FLOAT: triton_dot(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break; - case DOUBLE: triton_dot(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break; - default: break; - } - return bench; -} -bool test_dot(drv::context* context, drv::stream* stream, - dtype_t dtype, bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - const std::vector& a_order, const std::vector& b_order, - int32_t TM, int32_t TN, int32_t TK, size_t nwarp) { - std::vector bench; - bool test = false; - switch(dtype){ - case HALF: triton_dot(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break; - case FLOAT: triton_dot(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break; - case DOUBLE: triton_dot(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break; - default: break; - } - return test; -} diff --git a/tests/common/reduce.h b/tests/common/reduce.h deleted file mode 100644 index 3c0b79fc9..000000000 --- a/tests/common/reduce.h +++ /dev/null @@ -1,173 +0,0 @@ -#include -#include -#include -#include -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "triton/runtime/function.h" -#include "src/reduce.h" -#include "util.h" - -namespace drv = triton::driver; -namespace rt = triton::runtime; - -template -void cc_reduce_nd(std::vector &y, const std::vector &x, reduce_op_t op, size_t axis, const std::vector& shapes) { - assert(axis <= shapes.size() - 1); - // remove shape at index axis to get outer dimensions - std::vector outer = shapes; - outer.erase(outer.begin() + axis); - if(outer.empty()) - outer.push_back(1); - // retrieve shape at index axis to get inner dimension - int inner = shapes[axis]; - // accumualtion function - auto acc = get_accumulator(op); - // iterate over outer dimensions - _loop_nest(outer, [&](const std::vector& y_idx) { - T ret = 0; - auto x_idx = y_idx; - x_idx.insert(x_idx.begin() + axis, 0); - // accumulate over inner dimensions - for(int z = 0; z < inner; z++){ - x_idx[axis] = z; - ret = acc(ret, x[offset(x_idx, shapes)]); - } - y[offset(y_idx, outer)] = ret; - }); -} - -enum run_mode_t { - BENCH, - TEST -}; - -void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vector& shape_x, - int axis, reduce_op_t op, - const std::vector& x_order, const std::vector& y_order, - std::vector> TS, - run_mode_t mode, std::vector& bench, bool &test) { - typedef float NumericT; - std::string ty = "float"; - size_t dtsize = sizeof(NumericT); - drv::device* device = context->device(); - - - - // shape - std::vector shape_y = shape_x; - shape_y.erase(shape_y.begin() + axis); - - // rank - int rank_x = shape_x.size(); - int rank_y = shape_y.size(); - - // size - size_t size_x = 1; - for(int32_t d: shape_x) - size_x *= d; - size_t size_y = 1; - for(int32_t d: shape_y) - size_y *= d; - - // strides for x - std::vector x_shapename = {"S0", "S1", "S2"}; - std::vector x_strides = {"1"}; - for(int d = 0; d < rank_x - 1; d++) - x_strides.push_back(x_strides[d] + " * " + x_shapename[x_order[d]]); - - // strides for y - std::vector y_shapename = x_shapename; - y_shapename.erase(y_shapename.begin() + axis); - std::vector y_strides = {"1"}; - for(int d = 0; d < rank_y - 1; d++) - y_strides.push_back(y_strides[d] + " * " + y_shapename[y_order[d]]); - - // options - rt::options_space_t opts; - opts.defines.push_back({"TYPE", {ty}}); - for(int d = 0; d < rank_x; d++) - opts.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}}); - for(int d = 0; d < rank_y; d++) - opts.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}}); - if(TS.empty()) - TS = tile_nd(rank_x); - for(int d = 0; d < rank_x; d++) - opts.defines.push_back({"TS" + std::to_string(d), TS[d]}); - - std::vector axy; - for(int d = 0; d < rank_x; d++) - if(d != axis) - axy.push_back(d); - for(int d = 0; d < rank_y; d++) - opts.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}}); - for(int d = 0; d < rank_y; d++) - opts.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}}); - - std::string RED = ""; - for(int n = 0; n < rank_x; n++){ - if(n > 0) - RED += ", "; - RED += (n==axis) ? to_str(op) : ":"; - } - opts.defines.push_back({"RED", {RED}}); - opts.num_warps = {2}; - - // kernel - rt::function function(src::reduce_nd[rank_x - 1], opts); - - // input buffers - auto dx = std::unique_ptr(drv::buffer::create(context, size_x*dtsize)); - auto dy = std::unique_ptr(drv::buffer::create(context, size_y*dtsize)); - - // grid - std::stringstream oss; - rt::add_arg(oss, *dx->cu()); - rt::add_arg(oss, *dy->cu()); - rt::add_arg(oss, (uint32_t)shape_x[0]); - if(shape_x.size() > 1) rt::add_arg(oss, (uint32_t)shape_x[1]); - if(shape_x.size() > 2) rt::add_arg(oss, (uint32_t)shape_x[2]); - std::vector ts = {"TS0", "TS1", "TS2"}; - auto grid = grid_nd(shape_x, ts); - - // metrics - if(mode == BENCH){ - auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; }; - double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); - bench.push_back(gbps(triton_ns)); - } - -// rt::options_t opt; -// for(auto &x: opts.defines) -// opt.defines[x.first] = x.second[0]; -// opt.num_warps = 1; -// std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl; - - - // test triton - if(mode == TEST){ - std::vector hy(size_y); - std::vector ry(size_y); - std::vector hx(size_x); - init_zeros(hy); - init_rand(hx); - stream->write(&*dx, true, 0, hx); - function((void**)oss.str().data(), oss.str().size(), grid, stream, device); - stream->synchronize(); - stream->read(&*dy, true, 0, hy); - cc_reduce_nd(ry, hx, op, axis, shape_x); - test = testing::diff(hy, ry); - } -} - -bool do_test(drv::context* context, drv::stream* stream, std::vector shape, int axis, reduce_op_t op, int nwarp){ - std::vector bench; - bool test; - std::vector> TSS; - for(int32_t d: shape) - TSS.push_back({std::to_string(d)}); - triton_reduce_nd(context, stream, shape, axis, op, {0, 1, 2}, {0, 1, 2}, TSS, TEST, bench, test); - return test; -} diff --git a/tests/common/src/conv.h b/tests/common/src/conv.h deleted file mode 100644 index ace395575..000000000 --- a/tests/common/src/conv.h +++ /dev/null @@ -1,134 +0,0 @@ -namespace src { - - const char *conv = -R"( - __global__ void conv(TYPE *A __noalias __readonly __aligned(16), - TYPE *B __noalias __readonly __aligned(16), - TYPE *C __noalias __aligned(16), - float alpha, - // equivalent matmul - int M __retune, - int N __retune, - int K __retune, - // convolution properties - int pad_h, int pad_w, int stride_h, int stride_w, - // pointer increment - int *ADELTA, - // memory strides - int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8), - int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8), - int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) { - // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int ridz = get_program_id(2); - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; - // reduction splitting - K = K / TZ; - int rk[TK] = ridz * K + 0 ... TK; - - // unpack aggregate rows - // m = (z, p, q) - int rq[TM] = rm % QQ; - int rzp[TM] = rm / QQ; - int rp[TM] = rzp % PP; - int rz[TM] = rzp / PP; - // unpack aggregate reduction - // k = (ci, r, s) - int rs [TK] = rk % SS; - int rcir[TK] = rk / SS; - int rr [TK] = rcir % RR; - int rci [TK] = rcir / RR; - - // padding / striding - int rh_0[TM] = rp * stride_h - pad_h; - int rw_0[TM] = rq * stride_w - pad_w; - int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :]; - int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :]; - - // pointers to lhs - int offa[TM, TK] = rz [:, newaxis] * lda_z + - rci[newaxis, :] * lda_ci + - rh * lda_h + - rw * 1; - TYPE* pa[TM, TK] = A + offa; - // pointers to rhs - int offb[TK, TN] = rci[:, newaxis] * ldb_ci + - rr [:, newaxis] * ldb_r + - rs [:, newaxis] * ldb_s + - rn [newaxis, :] * 1; - TYPE* pb[TK, TN] = B + offb; - // pointers to delta - int* padelta[TK] = ADELTA + rk; - int adelta[TK] = *padelta; - - // prefetches operands - bool checkam[TM, TK] = rm[:, newaxis] < M; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - bool checkbn[TK, TN] = rn[newaxis, :] < N; - bool checkbk[TK, TN] = rk[:, newaxis] < K; - bool checkb[TK, TN] = checkbn && checkbk; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - - // reduction loop - float acc[TM, TN] = 0; - for(int k = K; k > 0; k -= TK){ - acc += a @ b; - // increment A - pa += adelta[newaxis, :]; - // bounds-checking A - rk += TK; - rs = rk % SS; - rcir = rk / SS; - rr = rcir % RR; - rh = rh_0[:, newaxis] + rr[newaxis, :]; - rw = rw_0[:, newaxis] + rs[newaxis, :]; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - // increment B - pb += TK * ldb_s; - // bounds-checking B - bool checkbk = k > TK; - bool checkb[TK, TN] = checkbn && checkbk; - a = checka ? *pa : 0; - b = *?(checkb)pb; - padelta += TK; - adelta = *padelta; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; - - // epilogue - rm = ridx * TM + 0 ... TM; - rn = ridy * TN + 0 ... TN; - rq = rm % QQ; - rzp = rm / QQ; - rp = rzp % PP; - rz = rzp / PP; - int offc[TM, TN] = rz [:, newaxis] * ldc_z + - rn [newaxis, :] * ldc_co + - rp [:, newaxis] * ldc_p + - rm [:, newaxis] * 1; - TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N; - -#if (TZ==1) - *?(checkc) pc = c; -#else - // accumulate partial result using spin-locks - int *plock = locks + rid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); -#endif -} -)"; - -} diff --git a/tests/common/src/copy.h b/tests/common/src/copy.h deleted file mode 100644 index 2b9eb6cdc..000000000 --- a/tests/common/src/copy.h +++ /dev/null @@ -1,67 +0,0 @@ -#ifndef _TRITON_TEST_SRC_COPY_H_ -#define _TRITON_TEST_SRC_COPY_H_ - -namespace src { - - const char *copy1d = -R"( -void copy1d(TYPE * X __noalias __readonly __aligned(16), - TYPE * Y __noalias __readonly __aligned(16), - int S0 __retune) { - int pid0 = get_program_id(0); - int rs0[TS0] = pid0 * TS0 + 0 ... TS0; - TYPE* px[TS0] = X + rs0; - TYPE* py[TS0] = Y + rs0; - *py = *px; -} -)"; - - const char *copy2d = -R"( -void copy2d(TYPE * X __noalias __readonly __aligned(16), - TYPE * Y __noalias __writeonly __aligned(16), - int S0 __multipleof(8) __retune, - int S1 __multipleof(8) __retune) { - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int rs0[TS0] = pid0 * TS0 + 0 ... TS0; - int rs1[TS1] = pid1 * TS1 + 0 ... TS1; - bool in_bounds[TS0, TS1] = rs0[:, newaxis] < S0 && rs1[newaxis, :] < S1; - TYPE* px[TS0, TS1] = X + rs0[:, newaxis] * STRIDE_XS0 + rs1[newaxis, :] * STRIDE_XS1; - TYPE* py[TS0, TS1] = Y + rs0[:, newaxis] * STRIDE_YS0 + rs1[newaxis, :] * STRIDE_YS1; - *?(in_bounds)py = *?(in_bounds)px; -} -)"; - - const char *copy3d = -R"( -void copy3d(TYPE * X __noalias __readonly __aligned(16), - TYPE * Y __noalias __writeonly __aligned(16), - int S0 __multipleof(8) __retune, - int S1 __multipleof(8) __retune, - int S2 __multipleof(8) __retune) { - // program id - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int pid2 = get_program_id(2); - // ranges - int rs0[TS0] = pid0 * TS0 + 0 ... TS0; - int rs1[TS1] = pid1 * TS1 + 0 ... TS1; - int rs2[TS2] = pid2 * TS2 + 0 ... TS2; - // X pointers - TYPE* px[TS0, TS1, TS2] = X + rs0[:, newaxis, newaxis] * STRIDE_XS0 - + rs1[newaxis, :, newaxis] * STRIDE_XS1 - + rs2[newaxis, newaxis, :] * STRIDE_XS2; - // Y pointers - TYPE* py[TS0, TS1, TS2] = Y + rs0[:, newaxis, newaxis] * STRIDE_YS0 - + rs1[newaxis, :, newaxis] * STRIDE_YS1 - + rs2[newaxis, newaxis, :] * STRIDE_YS2; - *py = *px; -} -)"; - - const char* copy_nd[] = {copy1d, copy2d, copy3d}; - -} - -#endif diff --git a/tests/common/src/reduce.h b/tests/common/src/reduce.h deleted file mode 100644 index 508ce896b..000000000 --- a/tests/common/src/reduce.h +++ /dev/null @@ -1,58 +0,0 @@ -namespace src { - - const char *reduce1d = -R"( -void reduce1d(TYPE * X __noalias __readonly __aligned(16), - TYPE * Y __noalias __readonly __aligned(16), - int S0) { - int pid0 = get_program_id(0); - int rs0[TS0] = pid0 * TS0 + 0 ... TS0; - TYPE* px[TS0] = X + rs0; - *Y = (*px)[RED]; -} -)"; - - - const char *reduce2d = -R"( -void reduce2d(TYPE * X __noalias __readonly __aligned(16), - TYPE * Y __noalias __writeonly __aligned(16), - int S0, int S1) { - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int rs0[TS0] = pid0 * TS0 + 0 ... TS0; - int rs1[TS1] = pid1 * TS1 + 0 ... TS1; - TYPE* px[TS0, TS1] = X + rs0[:, newaxis] * STRIDE_XS0 - + rs1[newaxis, :] * STRIDE_XS1; - TYPE* py[TY0] = Y + RY0 * STRIDE_YS0; - *py = (*px)[RED]; -} -)"; - - const char *reduce3d = -R"( -void reduce2d(TYPE * X __noalias __readonly __aligned(16), - TYPE * Y __noalias __writeonly __aligned(16), - int S0, int S1, int S2) { - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int pid2 = get_program_id(2); - int rs0[TS0] = pid0 * TS0 + 0 ... TS0; - int rs1[TS1] = pid1 * TS1 + 0 ... TS1; - int rs2[TS2] = pid2 * TS2 + 0 ... TS2; - // input pointers - TYPE* px[TS0, TS1, TS2] = X + rs0[:, newaxis, newaxis] * STRIDE_XS0 - + rs1[newaxis, :, newaxis] * STRIDE_XS1 - + rs2[newaxis, newaxis, :] * STRIDE_XS2; - // output pointers - TYPE* py[TY0, TY1] = Y + RY0[:, newaxis] * STRIDE_YS0 - + RY1[newaxis, :] * STRIDE_YS1; - // write-back - *py = (*px)[RED]; -} -)"; - - const char* reduce_nd[] = {reduce1d, reduce2d, reduce3d}; - - -} diff --git a/tests/common/util.h b/tests/common/util.h deleted file mode 100644 index 80774a3b5..000000000 --- a/tests/common/util.h +++ /dev/null @@ -1,225 +0,0 @@ -#pragma once - -#ifndef _TRITON_TESTS_UTIL_H -#define _TRITON_TESTS_UTIL_H - -#include -#include -#include "triton/runtime/function.h" - -namespace drv = triton::driver; -namespace rt = triton::runtime; - -/* ------------------------ - * Launch Grid - * ------------------------ */ - -inline size_t ceil(size_t x, size_t y) { - return (x + y - 1) / y; -} - -inline rt::function::grid_fn_ty grid1d(size_t N) { - return [N](const rt::options_t& x) { - return rt::grid_t{ceil(N, x.D("TN"))}; - }; -} - -inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) { - return [M, N](const rt::options_t& x) { - return rt::grid_t{ceil(M, x.D("TM")), - ceil(N, x.D("TN"))}; - }; -} - -inline rt::function::grid_fn_ty grid_nd(const std::vector &shape, - const std::vector& ts) { - return [&shape, &ts](const rt::options_t& x) { - rt::grid_t ret; - for(size_t d = 0; d < shape.size(); d++) - ret.push_back(ceil(shape[d], x.D(ts[d]))); - return ret; - }; -} - -inline std::vector> tile_nd(size_t rank) { - assert(rank <= 3); - if(rank == 1) - return {{"128", "256", "512", "1024"}}; - if(rank == 2) - return {{"16", "32", "64"}, - {"16", "32", "64"}}; - if(rank == 3) - return {{"4", "16", "32"}, - {"4", "16", "32"}, - {"4", "16", "32"}}; - return {}; -} - -/* ------------------------ - * Tensor Initialization - * ------------------------ */ - -template -void init_rand(std::vector& x) { - for(size_t i = 0; i < x.size(); i++) - x[i] = i; -} - -template -void init_zeros(std::vector& x) { - for(size_t i = 0; i < x.size(); i++) - x[i] = 0; -} - -template -void init_ones(std::vector& x) { - for(size_t i = 0; i < x.size(); i++) - x[i] = 1; -} - -/* ------------------------ - * Loop Nests - * ------------------------ */ - -void _loop_nest(std::vector const & ranges, - std::function const &)> const & f){ - int D = ranges.size(); - std::vector values(D, 0); - // Start with innermost loop - int i = D - 1; - while(true){ - // Execute function - f(values); - while(values[i]++ == ranges[i] - 1){ - if(i == 0) - return; - values[i--] = 0; - } - i = D - 1; - } -} - -/* ----------------------- - * TENSOR INDEXING - * ----------------------- */ - -enum order_t { - ROWMAJOR, - COLMAJOR -}; - - - - -int offset(const std::vector& idx, const std::vector& shapes) { - int result = idx[0]; - int ld = 1; - for(int i = 1; i < idx.size(); i++){ - ld *= shapes[i - 1]; - result += idx[i]*ld; - } - return result; -} - -/* ----------------------- - * REDUCTION HELPERS - * ----------------------- */ - -enum reduce_op_t { - ADD, - MAX, - MIN -}; - -std::string to_str(reduce_op_t op) { - switch (op) { - case ADD: return "+"; - case MAX: return "max"; - case MIN: return "min"; - default: break; - } - assert(false); - return ""; -} - -template -std::function get_accumulator(reduce_op_t op) { - switch (op) { - case ADD: return [](T x, T y) { return x + y; }; - case MAX: return [](T x, T y) { return std::max(x, y); }; - case MIN: return [](T x, T y) { return std::min(x, y); }; - default: break; - } - assert(false); - return std::function(); -} - - -/* ----------------------- - * TENSOR COMPARISON - * ----------------------- */ -namespace testing { - -template -bool diff(const std::vector& hc, const std::vector& rc) { -if(hc.size() != rc.size()) - return false; -for(size_t i = 0; i < hc.size(); i++) - if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ - std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; - return false; - } - return true; -} - -} - -/* ----------------------- - * PRETTY PRINTING - * ----------------------- */ - -namespace aux{ -template struct seq{}; - -template -struct gen_seq : gen_seq{}; - -template -struct gen_seq<0, Is...> : seq{}; - -template -void print_tuple(std::basic_ostream& os, Tuple const& t, seq){ - using swallow = int[]; - (void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get(t)), 0)...}; -} -} // aux:: - - -template -auto operator<<(std::basic_ostream& os, std::tuple const& t) - -> std::basic_ostream& -{ - aux::print_tuple(os, t, aux::gen_seq()); - return os; -} - -template -std::basic_ostream& operator<<(std::basic_ostream& os, const std::vector& vec) { - os << "{"; - for(size_t i = 0; i < vec.size(); i++){ - if(i > 0) - os << ", "; - os << vec[i]; - } - os << "}"; - return os; -} - -template -std::basic_ostream& operator<<(std::basic_ostream& os, reduce_op_t op) { - return os << to_str(op); -} - - - -#endif diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt deleted file mode 100644 index 6f397badb..000000000 --- a/tests/unit/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -foreach(PROG dot copy reduce) - set(TARGET unit_${PROG}) - add_executable(${TARGET} ${PROG}.cc) - set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET}) - target_link_libraries(${TARGET} triton dl) -endforeach(PROG) diff --git a/tests/unit/copy.cc b/tests/unit/copy.cc deleted file mode 100644 index 13a7b6270..000000000 --- a/tests/unit/copy.cc +++ /dev/null @@ -1,60 +0,0 @@ -#include -#include -#include "copy.h" -#include "triton/driver/backend.h" - - - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to benchmark - typedef std::tuple, std::vector, std::vector, std::vector> config_t; - std::vector configs; - // 1D -// configs.push_back({{65536}, {32}, {0}, {0}}); - configs.push_back({{65536}, {128}, {0}, {0}}); - configs.push_back({{65536}, {512}, {0}, {0}}); - configs.push_back({{65536}, {1024}, {0}, {0}}); - // 2D - configs.push_back({{256, 256}, {16, 16}, {0, 1}, {0, 1}}); - configs.push_back({{256, 256}, {16, 64}, {0, 1}, {0, 1}}); - configs.push_back({{256, 256}, {64, 16}, {0, 1}, {0, 1}}); - configs.push_back({{256, 256}, {64, 64}, {0, 1}, {0, 1}}); - configs.push_back({{256, 256}, {16, 16}, {0, 1}, {1, 0}}); - configs.push_back({{256, 256}, {16, 64}, {0, 1}, {1, 0}}); - configs.push_back({{256, 256}, {64, 16}, {0, 1}, {1, 0}}); - configs.push_back({{256, 256}, {64, 64}, {0, 1}, {1, 0}}); - configs.push_back({{256, 256}, {16, 16}, {1, 0}, {0, 1}}); - configs.push_back({{256, 256}, {16, 64}, {1, 0}, {0, 1}}); - configs.push_back({{256, 256}, {64, 16}, {1, 0}, {0, 1}}); - configs.push_back({{256, 256}, {64, 64}, {1, 0}, {0, 1}}); - configs.push_back({{256, 256}, {64, 64}, {1, 0}, {1, 0}}); - configs.push_back({{256, 256}, {16, 64}, {1, 0}, {1, 0}}); - configs.push_back({{256, 256}, {64, 16}, {1, 0}, {1, 0}}); - configs.push_back({{256, 256}, {64, 64}, {1, 0}, {1, 0}}); - // 3D - std::vector> xx_idx = {{0, 1, 2}, {2, 1, 0}, {1, 0, 2}}; - std::vector> yy_idx = {{0, 1, 2}, {2, 1, 0}, {1, 0, 2}}; - for(const auto& x_idx: xx_idx) - for(const auto& y_idx: yy_idx){ - configs.push_back({{64, 64, 32}, {16, 4, 8}, x_idx, y_idx}); - configs.push_back({{64, 64, 32}, {8, 16, 2}, x_idx, y_idx}); - configs.push_back({{64, 64, 32}, {32, 2, 2}, x_idx, y_idx}); - configs.push_back({{64, 64, 32}, {16, 64, 4}, x_idx, y_idx}); - } - // testing - std::vector shape, tile; - std::vector ord_x, ord_y; - bool result = true; - for(const auto& c: configs){ - std::tie(shape, tile, ord_x, ord_y) = c; - bool pass = test_copy_nd(context, stream, FLOAT, shape, tile, ord_x, ord_y); - result = result && pass; - std::cout << "// " << c << ", " << pass << std::endl; - } - return result; -} - - diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc deleted file mode 100644 index b7931e248..000000000 --- a/tests/unit/dot.cc +++ /dev/null @@ -1,70 +0,0 @@ -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "dot.h" -#include "util.h" - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to test - typedef std::tuple config_t; - std::vector configs; - for(dtype_t dtype: std::vector{FLOAT, HALF}) - for(bool AT: std::vector{false, true}) - for(bool BT: std::vector{false, true}){ - // 1 warp - configs.push_back({dtype, AT, BT, 16, 16, 16, 1}); - configs.push_back({dtype, AT, BT, 32, 16, 16, 1}); - configs.push_back({dtype, AT, BT, 16, 32, 16, 1}); - configs.push_back({dtype, AT, BT, 16, 16, 32, 1}); - configs.push_back({dtype, AT, BT, 32, 16, 32, 1}); - configs.push_back({dtype, AT, BT, 16, 32, 32, 1}); - if(dtype == HALF){ - configs.push_back({dtype, AT, BT, 16, 64, 16, 1}); - configs.push_back({dtype, AT, BT, 16, 16, 64, 1}); - configs.push_back({dtype, AT, BT, 64, 16, 64, 1}); - configs.push_back({dtype, AT, BT, 16, 64, 64, 1}); - } - // 2 warps - configs.push_back({dtype, AT, BT, 64, 32, 64, 2}); - configs.push_back({dtype, AT, BT, 32, 64, 64, 2}); - configs.push_back({dtype, AT, BT, 64, 32, 16, 2}); - configs.push_back({dtype, AT, BT, 32, 64, 16, 2}); - configs.push_back({dtype, AT, BT, 128, 32, 32, 2}); - configs.push_back({dtype, AT, BT, 32, 128, 32, 2}); - // 4 warps - configs.push_back({dtype, AT, BT, 128, 64, 16, 4}); - configs.push_back({dtype, AT, BT, 64, 128, 16, 4}); - configs.push_back({dtype, AT, BT, 128, 32, 32, 4}); - configs.push_back({dtype, AT, BT, 32, 128, 32, 4}); - if(dtype == HALF){ - configs.push_back({dtype, AT, BT, 128, 32, 64, 4}); - configs.push_back({dtype, AT, BT, 32, 128, 64, 4}); - } - // 8 warps - configs.push_back({dtype, AT, BT, 128, 256, 16, 8}); - configs.push_back({dtype, AT, BT, 256, 128, 16, 8}); - if(dtype == HALF){ - configs.push_back({dtype, AT, BT, 256, 128, 32, 8}); - configs.push_back({dtype, AT, BT, 256, 128, 32, 8}); - } - - }; - // test - dtype_t dtype; - bool AT, BT; - int M, N, K, TM, TN, TK, nwarp; - for(const auto& c: configs){ - std::tie(dtype, AT, BT, TM, TN, TK, nwarp) = c; - M = TM; - N = TN; - K = TK; - std::cout << "Testing " << c << " ... " << std::flush; - if(test_dot(context, stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp)) - std::cout << " Pass! " << std::endl; - else{ - std::cout << " Fail! " << std::endl; - } - } -} diff --git a/tests/unit/reduce.cc b/tests/unit/reduce.cc deleted file mode 100644 index 757582712..000000000 --- a/tests/unit/reduce.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include -#include -#include -#include -#include -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "triton/runtime/function.h" -#include "cuda/cublas.h" -#include "reduce.h" -#include "util.h" - - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to benchmark - typedef std::tuple, int, reduce_op_t> config_t; - std::vector configs = { - config_t{{64}, 0, ADD}, - config_t{{128}, 0, MIN}, - config_t{{32, 32}, 0, MAX}, - config_t{{32, 32}, 1, ADD}, - config_t{{32, 64}, 0, ADD}, - config_t{{64, 32}, 1, ADD}, - config_t{{8, 8, 4}, 2, ADD}, - config_t{{8, 8, 4}, 0, ADD}, - config_t{{8, 8, 4}, 1, ADD} - }; - // does the work - int axis; - std::vector shape; - reduce_op_t op; - for(const auto& c: configs){ - std::tie(shape, axis, op) = c; - std::cout << "Testing " << c << " ... " << std::flush; - if(do_test(context, stream, shape, axis, op, 1)) - std::cout << " Pass! " << std::endl; - else - std::cout << " Fail! " << std::endl; - } -} diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc new file mode 100644 index 000000000..6606ad22f --- /dev/null +++ b/tutorials/01-matmul.cc @@ -0,0 +1,245 @@ +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" +#include +#include +#include +#include +#include +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" +#include "triton/tools/bench.hpp" +#include "triton/external/half.hpp" +#include "triton/runtime/function.h" +#include +#include +#include "triton/runtime/function.h" + +namespace drv = triton::driver; +namespace rt = triton::runtime; + +namespace src { + + const char *dot = +R"( +#define STM 8 +#define STN 8 + +__global__ void dot(TYPE * A __noalias __readonly __aligned(16), + TYPE * B __noalias __readonly __aligned(16), + TYPE * C __noalias __aligned(16), + float alpha, + int M __retune, + int N __retune, + int K __retune __multipleof(16), + int lda __multipleof(8), + int ldb __multipleof(8), + int ldc __multipleof(8), + int* locks) { + // prologue + int pid = get_program_id(0); + int pidz = get_program_id(2); + int gridm = (M + TM - 1) / TM; + int gridn = (N + TN - 1) / TN; + int width = STM*gridn; + int stm = pid / width; + int RSTM = min(gridm - stm*STM, STM); + int stn = (pid % width) / (RSTM*STN); + int RSTN = min(gridn - stn*STN, STN); + int laneid = pid % (RSTM * RSTN); + int lanem = laneid / RSTN; + int lanen = laneid % RSTN; + int pidm = stm*STM + lanem; + int pidn = stn*STN + lanen; + int rm[TM] = pidm * TM + 0 ... TM; + int rn[TN] = pidn * TN + 0 ... TN; + + // reduction splitting + K = K / TZ; + int rk[TK] = pidz * K + 0 ... TK; + // pointers to operands + int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; + int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; + TYPE* pa[TM, TK] = A + offa; + TYPE* pb[TK, TN] = B + offb; + // prefetches operands + bool checka[TM, TK] = rk[newaxis, :] < K; + bool checkb[TK, TN] = rk[:, newaxis] < K; + TYPE a[TM, TK] = checka ? *pa : 0; + TYPE b[TK, TN] = checkb ? *pb : 0; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; + // reduction loop + float acc[TM, TN] = 0; + for(int k = K; k > 0; k -= TK){ + bool checka[TM, TK] = k > TK; + bool checkb[TK, TN] = k > TK; + acc += a @ b; + a = *?(checka)pa; + b = *?(checkb)pb; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; + } + acc = acc * alpha; + TYPE c[TM, TN] = acc; + + // epilogue + int rcm[TM] = pidm * TM + 0 ... TM; + int rcn[TN] = pidn * TN + 0 ... TN; + int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :]; + TYPE* pc[TM, TN] = C + offc; + bool checkc[TM, TN] = rcm[:, newaxis] < M && + rcn[newaxis, :] < N; +#if (TZ==1) + *?(checkc) pc = c; +#else + // accumulate partial result using spin-locks + int *plock = locks + rid; + int *pcount = plock + get_num_programs(0) * get_num_programs(1); + for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); + int count = *pcount; + if(count == 0) + *?(checkc) pc = c; + else + *?(checkc) pc = c + *?(checkc)pc; + atomic_xchg(pcount, (count + 1) % TZ); + atomic_xchg(plock, 0); +#endif +} +)"; + +} + +enum dtype_t { + FLOAT, + HALF, + DOUBLE +}; + +template +struct to_string; + +template<> struct to_string{ + static constexpr const char* value = "half"; +}; + +template<> struct to_string{ + static constexpr const char* value = "float"; +}; + +template<> struct to_string{ + static constexpr const char* value = "double"; +}; + +template +void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + const std::vector& a_order, const std::vector& b_order, + std::vector& bench, bool &test){ + std::string ty = to_string::value; + size_t dt_nbytes = sizeof(T); + drv::device* device = context->device(); + int32_t lda = (AT ^ a_order[0]==1) ? K : M; + int32_t ldb = (BT ^ b_order[0]==1) ? N : K; + int32_t ldc = N; + std::vector sa = { "1", "lda" }; + std::vector sb = { "1", "ldb" }; + // inputs + auto dc = std::shared_ptr(drv::buffer::create(context, M*N*dt_nbytes)); + auto da = std::shared_ptr(drv::buffer::create(context, M*K*dt_nbytes)); + auto db = std::shared_ptr(drv::buffer::create(context, K*N*dt_nbytes)); + auto dlocks = std::shared_ptr(drv::buffer::create(context, 1024*1024*2*4)); + // initialize buffers + std::vector hc(M*N); + std::vector ha(M*K); + std::vector hb(K*N); + for(size_t i = 0; i < ha.size(); i++) + ha[i] = (float)rand()/RAND_MAX; + for(size_t i = 0; i < hb.size(); i++) + hb[i] = (float)rand()/RAND_MAX; + // copy buffer + stream->write(&*da, true, 0, ha); + stream->write(&*db, true, 0, hb); + + // macros + rt::options_space_t opts; + // A access patterns + opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }}); + opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }}); + // B access patterns + opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }}); + opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }}); + // data-type + opts.defines.push_back({"TYPE", {ty}}); + // tile sizes + opts.defines.push_back({"TM", {"128"}}); + opts.defines.push_back({"TN", {"128"}}); + opts.defines.push_back({"TK", {"32"}}); + opts.defines.push_back({"TZ", {"1"}}); + opts.num_warps = {4}; + + // arguments + std::stringstream oss; + rt::add_arg(oss, *da->cu()); + rt::add_arg(oss, *db->cu()); + rt::add_arg(oss, *dc->cu()); + rt::add_arg(oss, (float)1); + rt::add_arg(oss, M); + rt::add_arg(oss, N); + rt::add_arg(oss, K); + rt::add_arg(oss, lda); + rt::add_arg(oss, ldb); + rt::add_arg(oss, ldc); + rt::add_arg(oss, *dlocks->cu()); + // kernel + rt::function function(src::dot, opts, device); + // grid + auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; + auto grid = [ceil, M, N](const rt::options_t& x) { + return rt::grid_t{ceil(M, x.D("TM"))* + ceil(N, x.D("TN")), + (size_t)x.D("TZ")}; + }; + + // metrics + auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; + double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream); + bench.push_back(tflops(triton_ns)); +} + +std::vector bench_dot(drv::context* context, drv::stream* stream, + dtype_t dtype, bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + const std::vector& a_order, const std::vector& b_order) { + std::vector bench; + bool test; + switch(dtype){ + case HALF: triton_dot(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break; + case FLOAT: triton_dot(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break; + case DOUBLE: triton_dot(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break; + default: break; + } + return bench; +} + + +int main() { + // initialize default compute device + auto context = triton::driver::backend::contexts::get_default(); + triton::driver::stream* stream = triton::driver::stream::create(context->backend()); + // shapes to benchmark + typedef std::tuple, bool, bool, int, int, int> config_t; + std::vector configs = { + {{1, 0}, false, false, 8192, 8192, 8192} + }; + // does the work + std::vector ord; + bool AT, BT; + int32_t M, N, K; + for(const auto& c: configs){ + std::tie(ord, AT, BT, M, N, K) = c; + std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K ; + for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord)) + std::cout << ", " << perf << std::flush; + std::cout << std::endl; + } +} diff --git a/tests/bench/CMakeLists.txt b/tutorials/CMakeLists.txt similarity index 75% rename from tests/bench/CMakeLists.txt rename to tutorials/CMakeLists.txt index f531fadf4..1dd6c9b8f 100644 --- a/tests/bench/CMakeLists.txt +++ b/tutorials/CMakeLists.txt @@ -1,5 +1,5 @@ -foreach(PROG dot copy conv) - set(TARGET bench_${PROG}) +foreach(PROG 01-matmul) + set(TARGET ${PROG}) add_executable(${TARGET} ${PROG}.cc) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET}) target_link_libraries(${TARGET} triton dl)