History prior to this date belonged to the now deprecated ISAAC project, and was deleted to save space

This commit is contained in:
Philippe Tillet
2021-07-27 12:38:38 -07:00
commit 6d7cf35123
202 changed files with 94034 additions and 0 deletions

45
CMakeLists.txt Normal file
View File

@@ -0,0 +1,45 @@
cmake_minimum_required(VERSION 2.8)
project(triton)
include(CTest)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
# Options
option(BUILD_TESTS "Build C++ Triton tests" ON)
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
# LLVM
find_package(LLVM REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
# Default build type
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Default build type: Release")
set(CMAKE_BUILD_TYPE "Release")
endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# Tests
if(BUILD_TESTS)
message(STATUS "Adding C++ tests")
add_subdirectory(tests)
endif()
# Python module
if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# PyBind11 wrapper source file
file(GLOB_RECURSE PYTHON_SRC python/src/bindings.cc)
include_directories(python/src/ ${PYTHON_INCLUDE_DIRS})
endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
link_directories(${LLVM_LIBRARY_DIRS})
target_link_libraries(triton ${LLVM_LIBRARIES})

26
LICENSE Executable file
View File

@@ -0,0 +1,26 @@
/* Copyright 2018-2019 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
// The compiler front-end is based on a modified version of WGTCC
// https://github.com/wgtdkp/wgtcc
// Copyright (c) 2016 wgtdkp

37
README.md Normal file
View File

@@ -0,0 +1,37 @@
# Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write custom ops at higher productivity than CUDA, but also with much higher flexibility than [TVM](https://github.com/apache/incubator-tvm).
The main scope of Triton at the moment are:
- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level using numpy-like extensions of the C language.
- **Triton-IR**: An intermediate-representation for optimizing multi-dimensional array operations in linear algebra programs
- **Triton-JIT**: An optimizing just-in-time compiler for Triton-C, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)) and PTX (e.g., [ISAAC](https://github.com/ptillet/isaac)). This includes transparent support for mixed-precision and Tensor Cores.
Bindings for **automatic** PyTorch custom op generations are included in - **PyTriton**, along with a small DSL based on einsum that supports convolutions, shift-convolutions, direct einsums, etc.
The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please cite us if you use our work!
## Installation
Triton is a fairly self-contained package and uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and LLVM code-generator. However, at the moment it relies on LLVM-8.0+ for PTX code generation.
```
sudo apt-get install llvm-8-dev
git clone https://github.com/ptillet/triton.git;
cd triton/python/;
python setup.py develop;
cd examples;
python einsum.py
```
## Tutorials
- [The Triton-C language](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
- [The PyTriton API](https://github.com/ptillet/triton/blob/master/docs/pytriton.md)
- Extended Einstein Summations (coming soon...)
- The Triton-IR representation (coming soon...)
- The Triton-JIT compiler (coming soon...)

166
cmake/FindLLVM.cmake Normal file
View File

@@ -0,0 +1,166 @@
# - Find LLVM headers and libraries.
# This module locates LLVM and adapts the llvm-config output for use with
# CMake.
#
# A given list of COMPONENTS is passed to llvm-config.
#
# The following variables are defined:
# LLVM_FOUND - true if LLVM was found
# LLVM_CXXFLAGS - C++ compiler flags for files that include LLVM headers.
# LLVM_HOST_TARGET - Target triple used to configure LLVM.
# LLVM_INCLUDE_DIRS - Directory containing LLVM include files.
# LLVM_LDFLAGS - Linker flags to add when linking against LLVM
# (includes -LLLVM_LIBRARY_DIRS).
# LLVM_LIBRARIES - Full paths to the library files to link against.
# LLVM_LIBRARY_DIRS - Directory containing LLVM libraries.
# LLVM_ROOT_DIR - The root directory of the LLVM installation.
# llvm-config is searched for in ${LLVM_ROOT_DIR}/bin.
# LLVM_VERSION_MAJOR - Major version of LLVM.
# LLVM_VERSION_MINOR - Minor version of LLVM.
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
#
# Note: The variable names were chosen in conformance with the offical CMake
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
# Try suffixed versions to pick up the newest LLVM install available on Debian
# derivatives.
# We also want an user-specified LLVM_ROOT_DIR to take precedence over the
# system default locations such as /usr/local/bin. Executing find_program()
# multiples times is the approach recommended in the docs.
set(llvm_config_names llvm-config-9 llvm-config-9.0 llvm-config90
llvm-config-8 llvm-config-8.0 llvm-config80
llvm-config)
find_program(LLVM_CONFIG
NAMES ${llvm_config_names}
PATHS ${LLVM_ROOT_DIR}/bin NO_DEFAULT_PATH
DOC "Path to llvm-config tool.")
find_program(LLVM_CONFIG NAMES ${llvm_config_names})
# Prints a warning/failure message depending on the required/quiet flags. Copied
# from FindPackageHandleStandardArgs.cmake because it doesn't seem to be exposed.
macro(_LLVM_FAIL _msg)
if(LLVM_FIND_REQUIRED)
message(FATAL_ERROR "${_msg}")
else()
if(NOT LLVM_FIND_QUIETLY)
message(STATUS "${_msg}")
endif()
endif()
endmacro()
if(NOT LLVM_CONFIG)
if(NOT LLVM_FIND_QUIETLY)
message(WARNING "Could not find llvm-config (LLVM >= ${LLVM_FIND_VERSION}). Try manually setting LLVM_CONFIG to the llvm-config executable of the installation to use.")
endif()
else()
macro(llvm_set var flag)
if(LLVM_FIND_QUIETLY)
set(_quiet_arg ERROR_QUIET)
endif()
set(result_code)
execute_process(
COMMAND ${LLVM_CONFIG} --${flag}
RESULT_VARIABLE result_code
OUTPUT_VARIABLE LLVM_${var}
OUTPUT_STRIP_TRAILING_WHITESPACE
${_quiet_arg}
)
if(result_code)
_LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'")
else()
if(${ARGV2})
file(TO_CMAKE_PATH "${LLVM_${var}}" LLVM_${var})
endif()
endif()
endmacro()
macro(llvm_set_libs var flag)
if(LLVM_FIND_QUIETLY)
set(_quiet_arg ERROR_QUIET)
endif()
set(result_code)
execute_process(
COMMAND ${LLVM_CONFIG} --${flag} ${LLVM_FIND_COMPONENTS}
RESULT_VARIABLE result_code
OUTPUT_VARIABLE tmplibs
OUTPUT_STRIP_TRAILING_WHITESPACE
${_quiet_arg}
)
if(result_code)
_LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'")
else()
file(TO_CMAKE_PATH "${tmplibs}" tmplibs)
string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_${var} ${tmplibs})
endif()
endmacro()
llvm_set(VERSION_STRING version)
llvm_set(CXXFLAGS cxxflags)
llvm_set(HOST_TARGET host-target)
llvm_set(INCLUDE_DIRS includedir true)
llvm_set(ROOT_DIR prefix true)
llvm_set(ENABLE_ASSERTIONS assertion-mode)
# The LLVM version string _may_ contain a git/svn suffix, so cut that off
string(SUBSTRING "${LLVM_VERSION_STRING}" 0 5 LLVM_VERSION_BASE_STRING)
# Versions below 4.0 do not support components debuginfomsf and demangle
if(${LLVM_VERSION_STRING} MATCHES "^3\\..*")
list(REMOVE_ITEM LLVM_FIND_COMPONENTS "debuginfomsf" index)
list(REMOVE_ITEM LLVM_FIND_COMPONENTS "demangle" index)
endif()
# Versions below 8.0 not supported
if(${LLVM_VERSION_STRING} MATCHES "^[3-7]\\..*")
message(FATAL_ERROR "LLVM version below 8.0 not supported")
endif()
llvm_set(LDFLAGS ldflags)
# In LLVM 3.5+, the system library dependencies (e.g. "-lz") are accessed
# using the separate "--system-libs" flag.
llvm_set(SYSTEM_LIBS system-libs)
string(REPLACE "\n" " " LLVM_LDFLAGS "${LLVM_LDFLAGS} ${LLVM_SYSTEM_LIBS}")
llvm_set(LIBRARY_DIRS libdir true)
llvm_set_libs(LIBRARIES libs)
# LLVM bug: llvm-config --libs tablegen returns -lLLVM-3.8.0
# but code for it is not in shared library
if("${LLVM_FIND_COMPONENTS}" MATCHES "tablegen")
if (NOT "${LLVM_LIBRARIES}" MATCHES "LLVMTableGen")
set(LLVM_LIBRARIES "${LLVM_LIBRARIES};-lLLVMTableGen")
endif()
endif()
# Versions below 4.0 do not support llvm-config --cmakedir
if(${LLVM_VERSION_STRING} MATCHES "^3\\..*")
set(LLVM_CMAKEDIR ${LLVM_LIBRARY_DIRS}/cmake/llvm)
else()
llvm_set(CMAKEDIR cmakedir)
endif()
llvm_set(TARGETS_TO_BUILD targets-built)
string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_TARGETS_TO_BUILD ${LLVM_TARGETS_TO_BUILD})
endif()
# Remove some clang-specific flags for gcc.
if(CMAKE_COMPILER_IS_GNUCXX)
string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
endif()
# Remove gcc-specific flags for clang.
if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS})
endif()
string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" )
string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" )
# Use the default CMake facilities for handling QUIET/REQUIRED.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM
REQUIRED_VARS LLVM_ROOT_DIR LLVM_HOST_TARGET
VERSION_VAR LLVM_VERSION_STRING)

243
docs/pytriton.md Normal file
View File

@@ -0,0 +1,243 @@
# The PyTriton API
## <span style="color:darkred"> Table of Contents </span>
1. [Motivations](#motivations)
2. [Triton Functions](#pytriton-function)
1. [Creation of Triton Kernels](#creation-triton-kernels)
2. [Usage of Triton Kernels](#usage-triton-kernels)
3. [Integration with Automatic Differentiation](#autodiff)
1. [Basics](#autodiff-basics)
2. [Convenience](#autodiff-convenience)
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>
The purpose of PyTriton is to provide an API for easily executing Triton-C kernels from PyTorch and Tensorflow. One of the main advantages of PyTriton is that it is framework agnostic: any custom op written using this API will be transparently compatible with both Tensorflow and PyTorch without any additional effort required, as will be shown in this tutorial.
Consider for example the following piece of code:
```python
import numpy as np
import triton
def run_tf():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True)
da, db = tf.gradients(c, [a, b])
# Run
ha = np.random.rand(M, K).astype(np.float32)
hb = np.random.rand(K, N).astype(np.float32)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([da], feed_dict = {a: ha, b: hb})
def run_torch():
M, N, K = 128, 128, 128
a = torch.randn(M, K).cuda()
b = torch.randn(K, N).cuda()
a.requires_grad_(True)
b.requires_grad_(True)
c = triton.ops.dot(a, b, False, True)
c.backward()
da = a.grad.clone()
db = b.grad.clone()
## Run on tensorflow
# import tensorflow as tf
# run_tf()
## Run on pytorch
# import torch
# run_torch()
```
PyTriton works by detecting which frameworks are imported and automatically generating and just-in-time compiling C++ binding code for them. Specifically, the following chain of events is triggered when a Triton operation is executed:
1. The imported frameworks are detected
2. C++ binding code for Tensorflow or PyTorch is generated, compiled and cached.
3. The corresponding custom-op is automatically loaded from the generated .so file, and a framework-agnostic wrapper is created.
4. The wrapper is called and a tf.tensor or a torch.tensor is returned. In the case of Tensorflow, the gradient is also registered at this point if applicable
The remainder of this tutorial will show you how to re-implement the above `triton.ops.dot` operation from scratch.
## <span style="color:darkred"> PyTriton Functions </span> <a name="pytriton-function"></a>
The PyTriton API provides a `triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class
```python
import triton
# Entry point
class _dot(triton.function):
@staticmethod
# Forward Pass
def forward(ctx, *args):
#...
@staticmethod
# Backward Pass
def backward(ctx, dy):
#...
```
### <span style="color:darkblue">Creation of Triton Kernels </span> <a name="creation-triton-kernel"></a>
PyTriton also provides a `triton.kernel` class which automatically takes care of interaction with the Triton-JIT as well as the generation and compilation of C++ framework bindings code. For our dot operation we create a kernel from the Triton-C code derived at the end of the [previous tutorial](https://github.com/ptillet/triton/blob/master/docs/triton-c.md)
```
src = """
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0);
int pn = get_program_id(1);
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
}
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
}
}
"""
kernel = triton.kernel(src, ['C'])
```
Note that the second argument to `triton.kernel` constructors indicates which of the operands our kernel function should return. Here, we only return `C`.
At this point, `kernel` is a callable object which takes the same signature as the `dot` function in our source code, except that pointers are treated as tensors:
```
[tensor, tensor, tensor, int, int, int, int, int, int]
```
### <span style="color:darkblue">Usage of Triton Kernels </span> <a name="usage-triton-kernels"></a>
However, in practice only A, B are provided by the user, and all the other `int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the `A` and `B` tensors, and then returns the results of a call to `kernel`:
```python
@staticmethod
def _call(a, b, transpose_a, transpose_b):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M, Ka = shape_a[0], shape_a[1]
Kb, N = shape_b[0], shape_b[1]
# transpose shapes
if transpose_a:
M, Ka = Ka, M
if transpose_b:
Kb, N = N, Kb
# contiguous dimensions
lda = M if transpose_a else Ka
ldb = Kb if transpose_b else N
ldc = N
# data-type
dtype = a.dtype
# allocate output
c = triton.empty([M, N], dtype = dtype)
# compute
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
# macros -- not necessary but makes kernel source-code simpler
macros = {# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
'STRIDE_AM' : '1' if transpose_a else 'lda',
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
# handle B transposition
'USE_B' : '^b' if transpose_b else 'b',
'STRIDE_BK' : '1' if transpose_b else 'ldb',
'STRIDE_BN' : 'ldb' if transpose_b else '1',
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [32, 64, 128], TN = [32, 64, 128], TK = [8], **macros)
```
While this code should be mostly self-explanatory, there are a few of noteworthy things worth pointing out
- `triton.shape` provides a framework-agnostic way to retrieve the shape of a tensor
- `triton.empty` creates an empty tensor of the specified dimensions
- `grid` corresponds to the grid with which our Triton kernel will be launched. Because in our case this grid depends on parametric tile variables, it is supplied as a function of compilation options `opt`, whose compile-time definition can be retrieved using `opt.d(name)`. Here, `opt.d('TM')` and `opt.d('TN')` retrieve the first and second tile dimension our kernel was compiled with. We also provide a helper `triton.cdiv` for ceil divisions.
- `macros` provides a list of preprocessor definitions to compile the kernel with. Alternatively, these can also be supplied as named argument to the `_dot.kernel`. We recall that lists can be supplied to the preprocessor, in which case an auto-tuning procedure will be triggered. Here, the value of `TM` and `TN` are both tuned between 32, 64 and 128.
## <span style="color:darkred"> Compatibility with Automatic Differentiation</span> <a name="autodiff"></a>
At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.
### <span style="color:darkblue"> Basics </span> <a name="autodiff-basics"></a>
PyTriton binds to Tensorflow's and PyTorch's automatic differentiation framework using a single, common API inspired by PyTorch. It consists of two static methods `forward` and `backward` that take a context as their first input:
```
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
return _dot._call(a, b, transpose_a, transpose_b)
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
t_a, t_b = ctx.t_a, ctx.t_b
if not t_a and not t_b:
da = _dot._call(dy, b, False, True)
db = _dot._call(a, dy, True, False)
elif not t_a and t_b:
da = _dot._call(dy, b, False, False)
db = _dot._call(dy, a, True, False)
elif t_a and not t_b:
da = _dot._call(b, dy, False, True)
db = _dot._call(a, dy, False, False)
elif t_a and t_b:
da = _dot._call(b, dy, True, True)
db = _dot._call(dy, a, True, True)
else:
assert False
return da, db, None, None, None, None, None, None, None
```
### <span style="color:darkblue">Convenience </span> <a name="autodiff-convenience"></a>
Still like for PyTorch, a callable operation can be created using the `apply` method of our `triton.function` class. We wrap it as a module variable for convenience:
```python
dot = _dot.apply
```
And that's it! Our custom op is now created and ready to be used with both PyTorch and Tensorflow.

436
docs/triton-c.md Normal file
View File

@@ -0,0 +1,436 @@
# The Triton-C Programming Language
## <span style="color:darkred"> Table of Contents </span>
1. [Motivations](#motivations)
2. [Vector Addition](#vector-addition)
1. [Differences with CUDA](#differences-with-cuda)
2. [Advantages over CUDA](#advantages-over-cuda)
1. [Vectorization](#vectorization)
2. [Parameterization](#parameterization)
3. [Auto-Tuning](#auto-tuning)
3. [Matrix Transposition](#matrix-transposition)
1. [Compute Kernel](#trans-compute-kernel)
2. [The __multipleof Attribute](#trans-multipleof)
3. [Conditional Dereferencing](#conditional-dereferencing)
4. [Matrix Multiplication](#matrix-multiplication)
1. [Compute Kernel](#matmul-compute-kernel)
2. [Optimizations](#optimizations)
1. [Pre-Fetching](#pre-fetching)
1. [Rematerialization](#rematerialization)
3. [Fused Transpositions and Auto-Tuning](#fused-trans-autotuning)
## <span style="color:darkred"> Motivations </span> <a name="motivations"></a>
In C and C++, arrays and pointers have similar semantics. Indeed, there is no native way to manipulate statically shaped multi-dimensional arrays (beyond initialization) as a whole:
```c
// C99
float x[16][8] = {3.14};
float y[16][8] = {5.17};
// z = x + y
float z[16][8];
#pragma unroll
for(int i = 0; i < 16; i++)
#pragma unroll
for(int j = 0; j < 8; j++)
z[i][j] = x[i][j] + y[i][j];
```
While it does not seem like a big deal at first sight, there are two issues with this:
- **Ergonomics**: Of course, it is possible to simplify the above code using functions in C
```
float z[16][8];
add(z, x, y, 16, 8);
```
but this would be semantically different as the loops can no longer be unrolled due to their bounds being now dynamic arguments of the add function. This can be mitigated using templates metaprogramming (and operator overloads) in C++:
```c
// C++
template<typename T, int M, int N>
class matrix;
matrix<float, 16, 8> x = {3.14};
matrix<float, 16, 8> y = {5.17};
matrix<float, 16, 8> z = x + y;
```
While this is better and now equivalent to our initial code snippet, the syntax is not quite as ergonomically satisfying as what native syntactic support could provide:
```c
// Triton-C
float x[16, 8] = 3.14;
float y[16, 8] = 5.17;
// float z[8, 8] = x + y; // doesn't compile -- incompatible shapes!
float z[16, 8] = x + y;
float u[16] = z[:, +]; // sum along the second axis
float v[16, 32] = u[:, newaxis]; // broadcasting along the second axis
```
which is valid _Triton-C_.
- **Portability**: One other issue with our initial C program is that it is not portable. While it will run well on a single CPU thread, the operation `z = x + y` would underutilize a GPU Streaming Processor as it would execute on a single thread only. For this reason, it would have to be rewritten in CUDA as follows:
```
// CUDA
// Launch on a block of 16 x 8 threads
float x = 3.14;
float y = 5.17;
float z = x + y
```
In Triton-C, the same code can be used across many different platforms (only CPUs and GPUs are supported at the moment). Furthermore, Triton-C is single-threaded, hence easier to write than CUDA.
- **Performance**: Another issue with our initial C code snippet is its performance. Although the loops are unrolled, the program does not carry any data-flow information pertaining to array operations. This issue gets more and more problematic as programs get increasingly complex, eventually culminating in matrix multiplication being remarkably hard to optimize.
This can be worked around using heavy metaprogramming techniques (see [CUTLASS](https://github.com/NVIDIA/cutlass)), but even then programmers still have to allocate and synchronize shared memory manually and endure prohibitively long compilation procedures not easily amenable to auto-tuning. For these reasons, most Deep-Learning frameworks still rely heavily on highly optimized subroutines (e.g., BLAS), which makes the development of novel custom primitives time-consuming for experts and almost impossible for others.
Triton addresses this issue by relying on **Triton-IR**, an LLVM-like IR for array operations, and **Triton-JIT**, an optimizing compiler for Triton-IR. These two systems are, however, beyond the scope of this tutorial. More information can be found [here](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
_Note: You might be thinking that this is exactly what [MLIR](https://github.com/tensorflow/mlir) was made for... and you're right! You can conceptually think of Triton-IR as a dialect for MLIR, and Triton-C as a frontend for it. I would like to integrate Triton-IR into MLIR in the future; If you're interested in making this a thing, let me know._
## <span style="color:darkred"> Vector Addition </span> <a name="vector-addition"></a>
### <span style="color:darkblue"> Differences with CUDA </span> <a name="differences-with-cuda"></a>
Let's start it off by looking at a simple example. Vector addition, in its most trivial Triton-C implementation, can be written as follows:
```c
// Triton-C
// launched on a grid of (N / 32) programs of 1 thread each
__global__ void add(int N, float *a, float *b, float* c) {
int id = get_program_id(0);
int off[32] = id * 32 + (0 ... 32)
*(c + off) = *(a + off) + *(b + off);
}
```
For reference, here is an equivalent CUDA kernel (NVCC will generate the same PTX code as Triton-JIT on the above code):
```c
// CUDA
// launched on a grid of (N / 32) programs of 32 threads each
__global__ void add(int N, float *a, float *b, float *c) {
int off = blockIdx.x * 32 + threadIdx.x;
c[off] = a[off] + b[off];
}
```
As you can see, there are three main differences between our Triton-C kernel and the equivalent CUDA:
- **The programming model is different**.
While Triton-C and CUDA both use a Single-Program, Multiple-Data (SPMD) programming model, each Triton-C kernel is single-threaded.
Therefore, `get_program_id({0, 1, 2})` is equivalent to `blockIdx.{x, y, z}`, but there is no such thing as `blockDim` and `threadIdx`.
- **The semantics of arrays is different**
In the above Triton-C kernel, `off` is an array of 32 consecutive integers: `int off[32] = {id * 32 + 0, id * 32 + 1, ..., id * 32 + 31}`.
As a result, the statement: `c + off` implicitly broadcast `c` and creates an array of 32 pointers. This could also be done explicitly as follows:
```
float* c_broadcast[32] = c;
float* c_ptr[32] = c_broadcast + off; // c_ptr = c + off
```
- **The semantics of the subscript operator is different**.
n C/CUDA, subscripting can be used to offset and dereference a pointer, but in Triton-C it can only be used to index and broadcast an array (think NumPy).
### <span style="color:darkblue"> Advantages over CUDA </span> <a name="advantages-over-cuda"></a>
At this point, the advantages of Triton-C over CUDA may not be obvious. But they should become clearer and clearer as this tutorial progresses. First and foremost, the purpose of this subsection is to show how Triton can be used to optimize vector additions by automatically taking care of load/store vectorization, code parameterization and auto-tuning -- all of which require nontrivial implementation efforts in CUDA.
#### <span style="color:purple"> Vectorization </span> <a name="vectorization"></a>
On some hardware architectures, vectorizing load/store operations can lead to better memory utilization and, in turn, noticeable performance gains. In general, 128-bit memory transactions are favored, leading to the following CUDA kernel:
```c
// CUDA
// launched on a grid of (N / 128) programs of 32 threads each
__global__ void add(int N, float4 *a, float4 *b, float4 *c) {
int off = blockIdx.x * 32 + threadIdx.x;
c[off] = a[off] + b[off];
}
```
Or, for half-precision inputs:
```c
// CUDA
// launched on a grid of (N / 256) programs of 32 threads each
__global__ void add(int N, half8 *a, half8 *b, half8 *c) {
int off = blockIdx.x * 32 + threadIdx.x;
c[off] = a[off] + b[off];
}
```
Now this is a bit annoying, because as a programmer you have to keep track of not only the ideal vector size for each data-type (which might change in future GPU architectures), but also of how many elements are processed in each thread-block -- and adjust the grid size of the kernel accordingly! Not to mention that you may want to tune the thread-block size as well.
In Triton-C, this is not a problem as the compiler will figure out automatically when and where vectorization should be used, without any change in the source-code necessary.
#### <span style="color:purple"> Parameterization </span> <a name="parameterization"></a>
Specifically, the Triton compiler would refuse to 4-way vectorize our above compute kernel because it would require the array `int off[32]` to be distributed over 8 threads, which is less than a warp. Fortunately, it turns out that this problem can be easily solved using preprocessor directrives to _parameterize_ our kernel:
```c
// Triton-C
// launched on a grid of (N / SIZE) programs of 1 thread each
__global__ void add(int N, TYPE* a, TYPE* b, TYPE* c) {
int id = get_program_id(0);
int off[SIZE] = id * SIZE + (0 ... SIZE);
*(c + off) = *(a + off) + *(b + off);
}
// Not vectorized when compiled with -DSIZE=32 -DTYPE=float
// 4-Vectorized when compiled with -DSIZE=128 -DTYPE=float
// 8-Vectorized when compiled with -DSIZE=256 -DTYPE=half
```
Now, `TYPE` and `SIZE` are preprocessors macros which can be specified at compile-time, thereby giving the Triton compiler enough information to vectorize when beneficial without requiring any additional code modification.
#### <span style="color:purple"> Auto-Tuning </span> <a name="auto-tuning"></a>
As it turns out, different input vector lengths `N` may require different values of `SIZE` to perform optimally. Fortunately, the Triton preprocessor also accepts lists of possible definitions for macros, in which case an auto-tuning procedure will be launched every-time new input sizes are encountered. For example, compiling the above kernel with the option`-DSIZE=[32, 64, 128, 256] -DTYPE=float`
will result in the parameter `SIZE` being automatically tuned every time a new value of `N` is encountered.
_Note: Tuning our reference CUDA kernel would be much more cumbersome, as template metaprogramming would have to be used to ensure that proper vector types would be used_
## <span style="color:darkred"> Matrix Transposition </span> <a name="matrix-transposition"></a>
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from _uncoalesced_ memory operations when writing back the transposed matrix to DRAM. Of course, this can be fixed by using shared memory as shown [here](https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/), but this comes at the cost of simplicity and -- more importantly -- interferes with auto-tuning.
### <span style="color:darkblue"> Compute Kernel </span> <a name="trans-compute-kernel"></a>
In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory in order to enable shared memory stores/loads. Therefore, an optimal Triton kernel for this operation would look like:
```c
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}
```
At a high level, this kernel loads a `TM x TN` tile from the input matrix `X`, transposes it and writes the resulting `TN x TM` tile to the output matrix `Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of `(M / TM) x (N / TN)` programs decomposed as follows:
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
```
pidm = 2
pidn = 1
```
- Statements (3) and (4) construct the ranges of indices:
```
rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]
```
which will be used in statements (5) and (6) to construct tiles of pointers
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
```
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
│ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx) │
```
- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics:
```
│ Y + (pidn*TN + 0) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + 0) + (pidm*TM + TM - 1)*ldy) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
│ Y + (pidn*TN + TN - 1) + (pidn*TN + 0)*ldy, ..., ..., Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy) │
```
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
### <span style="color:darkblue"> The __multipleof Attribute </span> <a name="trans-multipleof"></a>
The memory loads and store in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of e.g., 4. This is unfortunate because tensor dimensions can be easily made into nice powers of two in Deep Learning, due to batch-sizes and layer width being flexible.
For this reason, Triton provides a __multipleof(N) attributes for variables that are guaranteed to always be multiple of N. In the case of Matrix Transpositions, vector loads can be enabled by modifying the function's signature as follows:
```c
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) {
// ...
}
```
### <span style="color:darkblue"> Conditional Dereferencing </span> <a name="conditional-dereferencing"></a>
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
```c
// launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
// ...
// create bounds-checking mask
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b)
// conditional write-back using the conditional dereferencing operatior '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
```
Here, statements (7a) creates an array of booleans `checkx[TM, TN]` such that `checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator `*?(predicate) pointer`.
## <span style="color:darkred"> Matrix Multiplication </span> <a name="matrix-multiplication"></a>
The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see [CUTLASS](https://github.com/NVIDIA/cutlass)). We will also see how pre-processors macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.
_Note: Bounds-checking is ommitted throughout for the sake of clarity. This feature can be easily added into our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops._
### <span style="color:darkblue"> Compute Kernel </span> <a name="matmul-compute-kernel"></a>
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
```c
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
__global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0); //(1)
int pn = get_program_id(1); //(2)
int rm[TM] = pm * TM + 0 ... TM; //(3)
int rn[TN] = pn * TN + 0 ... TN; //(4)
int rk[TK] = 0 ... TK; //(5)
// initialize accumulator
float c[TM, TN] = 0; //(6)
// pointers to operands
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8)
// reduction loop
for(int k = K; k > 0; k-= TK){
// fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
}
// epilogue
TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14)
*pc = c; //(15)
}
```
Here, each kernel instance produces a `TM x TN` tile of the output matrix C as follows:
- Statements (1) - (2) fetch the id of the current program instance.
- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix `C`
- Statement (5) constructs a range of indices along the reduction axis: `rk = [0, 1, ..., TK - 1]`
- Statement (6) initialize a `TM x TN` array of accumulators to hold the result of `A[rm, :] x B[:, rn]`
- Statements (7) - (8) initializes arrays of pointers `pa` and `pb` to the operands `A` and `B` using logic similar to that of the above transposition kernel
- Statements (9) - (10) load tiles of operands by dereferencing `pa` and `pb`
- Statement (11) performs updates the accumulator array using Triton-C's matrix multiplication operator '@'
- Statements (12) - (13) updates `pa` and `pb`
- Statement (14) creates an array of pointers `pc` to the result matrix `C`
- Statement (15) writes back the accumulator to `C`
Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel:
- Automatic coalescing of load/store operations
- Automatic vectorization of load/store operations
- Stashing `a` and `b` to shared memory
- Automatic allocation of shared memory
- Automatic synchronization of shared memory
- Automatic padding of shared memory to avoid bank conflicts
- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0
### <span style="color:darkblue"> Optimizations </span> <a name="optimizations"></a>
Nonetheless, there are two important optimizations that the Triton compiler does not do automatically at the moment yet are critical to achieve peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source-code.
#### <span style="color:purple"> Pre-Fetching </span> <a name="pre-fetching"></a>
The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows:
```
// pre-fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
bool check = k > TK;
// pre-fetch operands
a = check ? *pa : 0;
b = check ? *pb : 0;
}
```
Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard.
#### <span style="color:purple"> Rematerialization </span> <a name="rematerialization"></a>
[Rematerialization](https://en.wikipedia.org/wiki/Rematerialization) is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only `rm` and `rn` need to be rematerialized, leading to the following epilogue:
```c
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
```
### <span style="color:darkblue"> Fused Transpositions and Auto-Tuning </span> <a name="fused-trans-autotuning"></a>
It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. This is also what is done in the [PyTriton](https://github.com/ptillet/triton/blob/master/python/triton/ops/dot.py) implementation of matrix-multiplication. Fortunately, this can be done by using pre-processors macros for tile shapes and broadcasting directives, leading to the following kernel:
```c
// Triton-C
// launched on a grid of (M / TM) x (N / TN) programs
void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
// prologue
int pm = get_program_id(0);
int pn = get_program_id(1);
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = (*pa);
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
b = *pb;
}
// epilogue
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
}
```
All matrix multiplications variants can then be retrieved using the following compilation option:
```c
// A is not transposed
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda -DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
// A is transposed
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1 -DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
// B is not transpose
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1 -DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
// B is transpose
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
```
Auto-tuning can also be handled using pre-processor macros:
```c
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
```

View File

@@ -0,0 +1,77 @@
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map>
#include <vector>
namespace triton {
namespace ir {
class value;
class module;
class phi_node;
class splat_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class align {
private:
struct cst_info {
unsigned num_cst;
unsigned value;
};
// helpers
std::vector<unsigned> get_shapes(ir::value *v);
// populate is_constant
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
void populate(ir::value *v);
public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};
}
}
}
#endif

View File

@@ -0,0 +1,47 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
namespace triton{
namespace ir{
class value;
class function;
class module;
}
namespace codegen{
namespace analysis{
class tiles;
class liveness;
class cts;
class allocation {
public:
allocation(liveness *live)
: liveness_(live) { }
// accessors
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
void run(ir::module& mod);
private:
std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_;
// dependences
liveness *liveness_;
};
}
}
}
#endif

View File

@@ -0,0 +1,51 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
#include "triton/tools/graph.h"
#include <map>
#include <vector>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
private:
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);
public:
axes();
void run(ir::module &mod);
// accessors
int get(ir::value *value, unsigned dim);
std::vector<int> get(ir::value *value);
private:
tools::graph<node_t> graph_;
std::map<node_t, size_t> axes_;
};
}
}
}
#endif

View File

@@ -0,0 +1,205 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
namespace triton{
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
namespace codegen{
namespace analysis{
class axes;
class align;
class layout_visitor;
class data_layout;
class mma884_layout;
class scanline_layout;
class shared_layout;
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_hmma_884(mma884_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
HMMA_884,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
size_t find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class mma884_layout: public data_layout {
public:
mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
private:
std::vector<int> fpw_;
std::vector<int> wpt_;
};
struct scanline_layout: public data_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
private:
std::vector<int> mts_;
std::vector<int> nts_;
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
public:
shared_layout(const data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
};
class layouts {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, data_layout*> &get_all() { return layouts_; }
size_t tmp(ir::instruction* i) { return tmp_.at((ir::value*)i);}
// execution
void run(ir::module &mod);
private:
analysis::axes* axes_;
analysis::align* align_;
size_t num_warps_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
};
}
}
}
#endif

View File

@@ -0,0 +1,67 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
namespace triton{
namespace ir{
class value;
class phi_node;
class function;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
typedef unsigned slot_index;
class tiles;
class layouts;
class data_layout;
struct segment {
slot_index start;
slot_index end;
bool contains(slot_index idx) const {
return start <= idx && idx < end;
}
bool intersect(const segment &Other){
return contains(Other.start) || Other.contains(start);
}
};
class liveness {
private:
typedef std::map<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); }
// run
void run(ir::module &mod);
private:
// analysis
layouts *layouts_;
intervals_map_t intervals_;
};
}
}
}
#endif

View File

@@ -0,0 +1,30 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <list>
namespace triton{
namespace ir{
class module;
}
namespace codegen{
class pass {
public:
virtual void run(ir::module& m);
};
class pass_manager {
public:
void add(pass* p);
void run(ir::module& m);
private:
std::list<pass*> passes;
};
}
}

View File

@@ -0,0 +1,171 @@
#pragma once
#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/selection/machine_value.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace codegen{
// forward
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
// forward
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
}
namespace triton{
namespace codegen{
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
target *tgt,
unsigned num_warps);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_program_inst(ir::get_num_program_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_add_inst(ir::atomic_add_inst*);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_recoalesce_inst(ir::recoalesce_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit_layout_hmma_884(analysis::mma884_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
void visit(ir::module &, llvm::Module &);
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
target *tgt_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *sh_mem_ptr_;
unsigned num_warps_;
std::set<ir::value*> seen_;
};
}
}
#endif

View File

@@ -0,0 +1,138 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_
#define _TRITON_SELECTION_MACHINE_LAYOUT_H_
#include <map>
#include "triton/codegen/analysis/layout.h"
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class value;
}
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
}
namespace triton{
namespace codegen{
class machine_data_layout {
public:
virtual tile* create(ir::value *v) = 0;
};
class machine_shared_layout: public machine_data_layout {
public:
machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
analysis::shared_layout* layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::allocation* alloc_;
Value *&sh_mem_ptr_;
analysis::shared_layout* layout_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
Value *offset_;
Value *ptr_;
Value *pre_ptr_;
Value *next_ptr_;
};
class machine_distributed_layout: public machine_data_layout {
public:
machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout* layout);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::data_layout* layout_;
};
class machine_mma884_layout: public machine_distributed_layout {
public:
machine_mma884_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_;
unsigned pack_size_1_;
unsigned num_packs_0_;
unsigned num_packs_1_;
};
class machine_scanline_layout: public machine_distributed_layout {
public:
machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::scanline_layout* layout);
};
}
}
#endif

View File

@@ -0,0 +1,152 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
#define _TRITON_SELECTION_MACHINE_VALUE_H_
#include <vector>
#include <map>
#include <functional>
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
typedef std::vector<Value*> indices_t;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class tile {
protected:
typedef std::vector<unsigned> shapes_t;
public:
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
virtual void set_value(indices_t idx, Value *v) = 0;
virtual Value* get_value(indices_t idx) = 0;
Type *get_ty() const { return ty_; }
shapes_t get_shapes() const { return shapes_; }
protected:
Type *ty_;
shapes_t shapes_;
};
class shared_tile: public tile {
private:
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
public:
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector);
void set_value(indices_t, Value *);
Value* get_ptr_to(indices_t idx);
Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; }
const std::vector<int>& get_perm() { return perm_; }
const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
private:
Value *ptr_;
bool return_vector_;
Builder &builder_;
Value *offset_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
std::vector<int> order_;
std::vector<int> perm_;
};
// Distribtued tile
class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
typedef std::vector<indices_t> ordered_indices_vec_t;
typedef std::map<indices_t, unsigned> indices_map_t;
typedef std::map<indices_t, Value*> values_map_t;
private:
void init_indices();
public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
private:
axes_t axes_;
std::vector<int> order_;
indices_map_t indices_;
values_map_t values_;
ordered_indices_vec_t ordered_indices_;
Builder &builder_;
};
}
}
#endif

View File

@@ -0,0 +1,98 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// typedefs
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
virtual ~target() {}
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
bool is_gpu() const;
private:
bool is_gpu_;
};
class amd_cl_target: public target {
public:
amd_cl_target(): target(true){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class nvidia_cu_target: public target {
public:
nvidia_cu_target(): target(true){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class cpu_target: public target {
public:
cpu_target(): target(false){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 1; }
};
}
}
#endif

View File

@@ -0,0 +1,47 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace ir {
class module;
class value;
class io_inst;
class instruction;
class builder;
}
namespace codegen{
namespace analysis{
class align;
class layouts;
class cts;
}
namespace transform{
class coalesce {
private:
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
void run(ir::module &mod);
private:
analysis::align* align_;
analysis::layouts* layout_;
};
}
}
}
#endif

View File

@@ -0,0 +1,28 @@
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#include <set>
#include <map>
namespace triton {
namespace ir {
class module;
class value;
class phi_node;
class instruction;
}
namespace codegen{
namespace transform{
class cts {
public:
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -0,0 +1,24 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class dce {
public:
dce() {}
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -0,0 +1,22 @@
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class disassociate {
public:
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -0,0 +1,59 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
namespace triton {
namespace ir {
class module;
class basic_block;
class instruction;
class value;
class builder;
}
namespace codegen{
namespace analysis{
class allocation;
class liveness;
class layouts;
class cts;
}
namespace transform{
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
private:
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
void insert_barrier(ir::instruction *instr, ir::builder &builder);
bool intersect(const interval_vec_t &X, interval_t x);
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
};
}
}
}
#endif

View File

@@ -0,0 +1,42 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
namespace triton {
namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
class constant_int;
class dot_inst;
}
namespace codegen{
namespace transform{
class peephole {
private:
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
private:
public:
peephole() {}
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -0,0 +1,49 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
#define TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H
#include <map>
#include <set>
#include <vector>
namespace triton {
// forward declaration
namespace ir {
class module;
class value;
class builder;
class instruction;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class tiles;
class align;
}
namespace transform{
class reassociate {
struct cst_info {
ir::value* dyn_ptr;
ir::getelementptr_inst* sta_ptr;
};
private:
ir::instruction* is_bin_add(ir::value *x);
ir::value *reassociate_idx(ir::value *value, ir::builder &builder, ir::value *&noncst, ir::value *&cst);
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
void run(ir::module& module);
};
}
}
}
#endif

137
include/triton/driver/backend.h Executable file
View File

@@ -0,0 +1,137 @@
#pragma once
#ifndef _TRITON_DRIVER_BACKEND_H_
#define _TRITON_DRIVER_BACKEND_H_
#include <map>
#include <list>
#include <vector>
#include "triton/driver/context.h"
namespace llvm
{
class Module;
}
namespace triton
{
namespace driver
{
class buffer;
class stream;
class device;
class context;
class platform;
class module;
class kernel;
struct backend
{
// platforms
class platforms
{
friend class backend;
private:
static void init();
public:
static void get(std::vector<driver::platform*> &results);
private:
static std::vector<driver::platform*> cache_;
};
// devices
class devices
{
friend class backend;
private:
static void init(const std::vector<platform *> &platforms);
public:
static void get(std::vector<driver::device*>& devs);
private:
static std::vector<driver::device*> cache_;
};
// modules
class modules
{
friend class backend;
public:
static void release();
private:
static std::map<std::tuple<driver::stream*, std::string>, driver::module*> cache_;
};
// kernels
class kernels
{
friend class backend;
public:
static void release();
static driver::kernel* get(driver::module* mod, const std::string & name);
private:
static std::map<std::tuple<module*, std::string>, driver::kernel*> cache_;
};
// contexts
class contexts
{
friend class backend;
private:
static void init(const std::vector<device *> &);
static void release();
public:
static driver::context* get_default();
static driver::context* import(CUcontext ctx)
{
for(driver::context* x: cache_){
driver::cu_context* cu_x = (driver::cu_context*)x;
if(*cu_x->cu()==ctx)
return x;
}
cache_.emplace_back(new driver::cu_context(ctx, false));
return cache_.back();
}
static void get(std::list<driver::context*> &);
private:
static std::list<driver::context*> cache_;
};
// streams
class streams
{
friend class backend;
private:
static void init(std::list<context*> const &);
static void release();
public:
static void get(driver::context*, std::vector<driver::stream *> &streams);
static driver::stream* get(driver::context*, unsigned int id = 0);
static driver::stream* get_default();
private:
static std::map<driver::context*, std::vector<driver::stream*> > cache_;
};
static void init();
static void release();
static void synchronize(triton::driver::context *);
static unsigned int default_device;
};
}
}
#endif

57
include/triton/driver/buffer.h Executable file
View File

@@ -0,0 +1,57 @@
#pragma once
#ifndef _TRITON_DRIVER_BUFFER_H_
#define _TRITON_DRIVER_BUFFER_H_
#include "triton/driver/handle.h"
#include "triton/driver/context.h"
namespace triton
{
namespace driver
{
class stream;
// Base
class buffer : public polymorphic_resource<CUdeviceptr, cl_mem, host_buffer_t> {
public:
buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership);
buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
static buffer* create(driver::context* ctx, size_t size);
driver::context* context();
size_t size();
protected:
driver::context* context_;
size_t size_;
};
// CPU
class host_buffer: public buffer
{
public:
host_buffer(driver::context* context, size_t size);
};
// OpenCL
class ocl_buffer: public buffer
{
public:
ocl_buffer(driver::context* context, size_t size);
};
// CUDA
class cu_buffer: public buffer
{
public:
cu_buffer(driver::context* context, size_t size);
cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
void set_zero(triton::driver::stream *queue, size_t size);
};
}
}
#endif

70
include/triton/driver/context.h Executable file
View File

@@ -0,0 +1,70 @@
#pragma once
#ifndef _TRITON_DRIVER_CONTEXT_H_
#define _TRITON_DRIVER_CONTEXT_H_
#include "triton/driver/device.h"
#include "triton/driver/handle.h"
namespace triton
{
namespace driver
{
class context: public polymorphic_resource<CUcontext, cl_context, host_context_t>{
protected:
static std::string get_cache_path();
public:
context(driver::device *dev, CUcontext cu, bool take_ownership);
context(driver::device *dev, cl_context cl, bool take_ownership);
context(driver::device *dev, host_context_t hst, bool take_ownership);
driver::device* device() const;
std::string const & cache_path() const;
// factory methods
static context* create(driver::device *dev);
protected:
driver::device* dev_;
std::string cache_path_;
};
// Host
class host_context: public context {
public:
host_context(driver::device* dev);
};
// CUDA
class cu_context: public context {
public:
class context_switcher{
public:
context_switcher(driver::context const & ctx);
~context_switcher();
private:
driver::cu_context const & ctx_;
};
private:
static CUdevice get_device_of(CUcontext);
public:
//Constructors
cu_context(CUcontext cu, bool take_ownership = true);
cu_context(driver::device* dev);
};
// OpenCL
class ocl_context: public context {
public:
ocl_context(driver::device* dev);
};
}
}
#endif

229
include/triton/driver/cublas.h Executable file
View File

@@ -0,0 +1,229 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TDL_INCLUDE_DRIVER_CUBLAS_H
#define TDL_INCLUDE_DRIVER_CUBLAS_H
#include "isaac/templates/common.hpp"
#include "triton/driver/dispatch.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/driver/backend.h"
#include "triton/driver/error.h"
#include "triton/tools/bench.hpp"
#include "triton/tools/collections.hpp"
namespace triton
{
namespace driver
{
enum cublasStrategy_t{
CUBLAS_PREFER_FASTEST,
CUBLAS_HEURISTICS
};
static const std::vector<cublasGemmAlgo_t> cublasAlgorithms = {
CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3,
CUBLAS_GEMM_ALGO4, CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7
};
static const std::map<DType, cudaDataType> cudtype = {{FLOAT_TYPE, CUDA_R_32F}, {DOUBLE_TYPE,CUDA_R_64F}};
static const std::map<char, cublasOperation_t> cuop = {{'N', CUBLAS_OP_N}, {'T', CUBLAS_OP_T}};
inline cublasGemmAlgo_t cublasGemmFastest(stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc){
typedef std::tuple<cudaDataType_t, cublasOperation_t, cublasOperation_t, int32_t, int32_t, int32_t> key_t;
// Benchmark fastest algorithm in cublasGemmEx
auto benchmark_fastest = [&](key_t const &){
std::vector<double> times;
for(cublasGemmAlgo_t a: cublasAlgorithms){
try{
times.push_back(bench([&](){ dispatch::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, a); },
[&](){ stream.synchronize(); },
stream.context().device()));
}catch(driver::exception::cublas::base const &){
times.push_back(INFINITY);
}
}
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
return cublasAlgorithms[argmin];
};
// Cache result
static cpp::CachedMap<key_t, cublasGemmAlgo_t> cache(benchmark_fastest);
return cache.get(std::make_tuple(cudt, AT, BT, M, N, K));
}
/* Wrapper for cublasGemmEx */
inline void cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
{ dispatch::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo); }
/* Simplified API for default GEMM */
inline void cublasGemm(DType dtype, stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, cu_buffer const & A, int32_t lda, cu_buffer const & B, int32_t ldb, scalar beta, cu_buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
ContextSwitcher ctx_switch(stream.context());
cublasHandle_t handle = dispatch::cublasHandle(stream.context());
dispatch::cublasSetStream_v2(handle, (CUstream)stream);
if(fastest)
*fastest = cublasGemmFastest(stream, handle, cudtype.at(dtype), cuop.at(cAT), cuop.at(cBT), M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
else
cublasGemmEx(handle, cudtype.at(dtype), cuop.at(cAT), cuop.at(cBT), M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc, algo);
}
inline cudnnDataType_t cudnnDtype(DType dtype){
switch(dtype){
case INT8X4_TYPE: return CUDNN_DATA_INT8x4;
case INT32_TYPE: return CUDNN_DATA_INT32;
case FLOAT_TYPE: return CUDNN_DATA_FLOAT;
case DOUBLE_TYPE: return CUDNN_DATA_DOUBLE;
}
throw;
}
inline cudnnTensorFormat_t format(cudnnDataType_t cutype){
switch(cutype){
case CUDNN_DATA_INT8x4: return CUDNN_TENSOR_NCHW_VECT_C;
default: return CUDNN_TENSOR_NCHW;
}
}
inline void cudnnConv(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, cu_buffer const & I, cu_buffer const & F, scalar beta, cu_buffer const & O){
driver::driver::context const & ctx = stream.context();
ContextSwitcher switch_ctx(ctx);
std::vector<int> pad = {pad_d, pad_h, pad_w};
std::vector<int> stride = {stride_d, stride_h, stride_w};
std::vector<int> upscale = {1, 1, 1};
std::vector<int> Oshapes = {N, K, M, P, Q};
std::vector<int> Fshapes = {K, C, T, R, S};
std::vector<int> Ishapes = {N, C, D, H, W};
if(M == 1 && T == 1 && D == 1){
pad.erase(pad.begin());
stride.erase(stride.begin());
upscale.erase(upscale.begin());
Oshapes.erase(Oshapes.begin() + 2);
Ishapes.erase(Ishapes.begin() + 2);
Fshapes.erase(Fshapes.begin() + 2);
}
cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
cudnnDataType_t in_cutype = cudnnDtype(dtype);
cudnnDataType_t conv_cutype = (dtype == INT8X4_TYPE)?CUDNN_DATA_INT32:in_cutype;
dispatch::cudnnSetStream(handle, (CUstream)stream);
cudnnTensorDescriptor_t tO, tI;
cudnnFilterDescriptor_t tF;
cudnnConvolutionDescriptor_t conv;
cudnnConvolutionFwdAlgo_t algo;
dispatch::cudnnCreateTensorDescriptor(&tO);
dispatch::cudnnCreateTensorDescriptor(&tI);
dispatch::cudnnCreateFilterDescriptor(&tF);
dispatch::cudnnSetTensorNdDescriptorEx(tO, format(in_cutype), in_cutype, Oshapes.size(), Oshapes.data());
dispatch::cudnnSetFilterNdDescriptor(tF, in_cutype, format(in_cutype), Fshapes.size(), Fshapes.data());
dispatch::cudnnSetTensorNdDescriptorEx(tI, format(in_cutype), in_cutype, Ishapes.size(), Ishapes.data());
dispatch::cudnnCreateConvolutionDescriptor(&conv);
dispatch::cudnnSetConvolutionNdDescriptor(conv, pad.size(), pad.data(), stride.data(), upscale.data(), CUDNN_CROSS_CORRELATION, conv_cutype);
dispatch::cudnnGetConvolutionForwardAlgorithm(handle, tI, tF, conv, tO, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 1024*1024*64, &algo);
size_t workspace_size;
dispatch::cudnnGetConvolutionForwardWorkspaceSize(handle, tI, tF, conv, tO, algo, &workspace_size);
static cu_buffer work(ctx, 1024*1024*64);
CUdeviceptr twork = work;
CUdeviceptr pI = I, pF = F, pO = O;
dispatch::cudnnConvolutionForward(handle, alpha.data(), tI, (void*)pI, tF, (void*)pF, conv, algo, (void*)twork, workspace_size, beta.data(), tO, (void*)pO);
}
inline void cudnnPool(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, cu_buffer const & I, scalar beta, cu_buffer const & O){
driver::driver::context const & ctx = stream.context();
ContextSwitcher switch_ctx(ctx);
std::vector<int> pad = {pad_d, pad_h, pad_w};
std::vector<int> stride = {stride_d, stride_h, stride_w};
std::vector<int> upscale = {1, 1, 1};
std::vector<int> Oshapes = {N, K, M, P, Q};
std::vector<int> Ishapes = {N, K, D, H, W};
std::vector<int> window = {T, R, S};
if(M == 1 && T == 1 && D == 1){
window.erase(window.begin());
pad.erase(pad.begin());
stride.erase(stride.begin());
upscale.erase(upscale.begin());
Oshapes.erase(Oshapes.begin() + 2);
Ishapes.erase(Ishapes.begin() + 2);
}
cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
cudnnDataType_t cutype = cudnnDtype(dtype);
dispatch::cudnnSetStream(handle, (CUstream)stream);
cudnnTensorDescriptor_t tO, tI;
cudnnPoolingDescriptor_t desc;
dispatch::cudnnCreateTensorDescriptor(&tO);
dispatch::cudnnCreateTensorDescriptor(&tI);
dispatch::cudnnSetTensorNdDescriptorEx(tO, CUDNN_TENSOR_NCHW, cutype, Oshapes.size(), Oshapes.data());
dispatch::cudnnSetTensorNdDescriptorEx(tI, CUDNN_TENSOR_NCHW, cutype, Ishapes.size(), Ishapes.data());
dispatch::cudnnCreatePoolingDescriptor(&desc);
dispatch::cudnnSetPoolingNdDescriptor(desc, CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, window.size(), window.data(), pad.data(), stride.data());
CUdeviceptr pI = I, pO = O;
dispatch::cudnnPoolingForward(handle, desc, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
}
inline void cudnnTransformTensor(driver::cu_stream & stream,
DType in_dtype, DType out_dtype,
cudnnTensorFormat_t in_layout, cudnnTensorFormat_t out_layout,
int32_t N, int32_t C, int32_t D, int32_t H, int32_t W,
scalar alpha, driver::cu_buffer const & I, scalar beta, driver::cu_buffer& O)
{
cudnnHandle_t handle = dispatch::cudnnHandle(stream.context());
dispatch::cudnnSetStream(handle, (CUstream)stream);
cudnnTensorDescriptor_t tO, tI;
std::vector<int> shapes = {N, C, D, H, W};
dispatch::cudnnCreateTensorDescriptor(&tI);
dispatch::cudnnSetTensorNdDescriptorEx(tI, in_layout, cudnnDtype(in_dtype), shapes.size(), shapes.data());
dispatch::cudnnCreateTensorDescriptor(&tO);
dispatch::cudnnSetTensorNdDescriptorEx(tO, out_layout, cudnnDtype(out_dtype), shapes.size(), shapes.data());
CUdeviceptr pI = I, pO = O;
dispatch::cudnnTransformTensor(handle, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
}
}
}
#endif

110
include/triton/driver/device.h Executable file
View File

@@ -0,0 +1,110 @@
#pragma once
#ifndef _TRITON_DRIVER_DEVICE_H_
#define _TRITON_DRIVER_DEVICE_H_
#include "triton/driver/platform.h"
#include "triton/driver/handle.h"
namespace triton
{
namespace codegen
{
class target;
}
namespace driver
{
class context;
// Base device
class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
public:
using polymorphic_resource::polymorphic_resource;
virtual size_t max_threads_per_block() const = 0;
virtual size_t max_shared_memory() const = 0;
virtual std::unique_ptr<codegen::target> make_target() const = 0;
};
// Host device
class host_device: public device {
public:
host_device(): device(host_device_t(), true){ }
size_t max_threads_per_block() const { return 1; }
size_t max_shared_memory() const { return 0; }
std::unique_ptr<codegen::target> make_target() const;
};
// OpenCL device
class ocl_device: public device {
public:
ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
std::unique_ptr<codegen::target> make_target() const;
};
// CUDA device
class cu_device: public device {
public:
//Supported architectures
enum class Architecture{
//NVidia
SM_2_0,
SM_2_1,
SM_3_0,
SM_3_5,
SM_3_7,
SM_5_0,
SM_5_2,
SM_6_0,
SM_6_1,
SM_7_0,
UNKNOWN
};
private:
//Metaprogramming elper to get cuda info from attribute
template<CUdevice_attribute attr>
int cuGetInfo() const;
inline Architecture nv_arch(std::pair<unsigned int, unsigned int> sm) const;
inline nvmlDevice_t nvml_device() const;
public:
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
// Accessors
Architecture architecture() const;
// Informations
std::string infos() const;
size_t address_bits() const;
std::vector<size_t> max_block_dim() const;
size_t warp_size() const;
// Compute Capability
void interpret_as(std::pair<size_t, size_t> cc);
std::pair<size_t, size_t> compute_capability() const;
// Identifier
std::string name() const;
std::string pci_bus_id() const;
// Clocks
size_t current_sm_clock() const;
size_t current_mem_clock() const;
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
size_t max_sm_clock() const;
size_t max_mem_clock() const;
void set_max_clock();
// Target
std::unique_ptr<codegen::target> make_target() const;
private:
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
};
}
}
#endif

259
include/triton/driver/dispatch.h Executable file
View File

@@ -0,0 +1,259 @@
#pragma once
#ifndef _TRITON_DRIVER_DISPATCH_H_
#define _TRITON_DRIVER_DISPATCH_H_
#include <type_traits>
#include <dlfcn.h>
//CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
#include "triton/external/CL/cl.h"
#include "triton/external/CL/cl_ext.h"
//Exceptions
#include <iostream>
#include <stdexcept>
namespace llvm {
class PassRegistry;
class Module;
}
namespace triton
{
namespace driver
{
class cu_context;
template<class T> void check(T){}
void check(CUresult err);
void check(cl_int err);
class dispatch
{
protected:
template <class F>
struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
initializer();
if(cache == nullptr){
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
check(res);
return res;
}
public:
static bool clinit();
static bool nvmlinit();
static bool cuinit();
static bool spvllvminit();
static void release();
// OpenCL
static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *);
static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *);
static cl_int clReleaseMemObject(cl_mem);
static cl_int clFinish(cl_command_queue);
static cl_int clGetMemObjectInfo(cl_mem, cl_mem_info, size_t, void *, size_t *);
static cl_int clGetCommandQueueInfo(cl_command_queue, cl_command_queue_info, size_t, void *, size_t *);
static cl_int clReleaseContext(cl_context);
static cl_int clReleaseEvent(cl_event);
static cl_int clEnqueueWriteBuffer(cl_command_queue, cl_mem, cl_bool, size_t, size_t, const void *, cl_uint, const cl_event *, cl_event *);
static cl_int clEnqueueReadBuffer(cl_command_queue, cl_mem, cl_bool, size_t, size_t, void *, cl_uint, const cl_event *, cl_event *);
static cl_int clGetProgramBuildInfo(cl_program, cl_device_id, cl_program_build_info, size_t, void *, size_t *);
static cl_int clReleaseDevice(cl_device_id);
static cl_context clCreateContext(const cl_context_properties *, cl_uint, const cl_device_id *, void (*)(const char *, const void *, size_t, void *), void *, cl_int *);
static cl_int clGetDeviceIDs(cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *);
static cl_int clGetContextInfo(cl_context, cl_context_info, size_t, void *, size_t *);
static cl_int clGetDeviceInfo(cl_device_id, cl_device_info, size_t, void *, size_t *);
static cl_int clReleaseCommandQueue(cl_command_queue);
static cl_int clGetPlatformIDs(cl_uint, cl_platform_id *, cl_uint *);
static cl_int clGetPlatformInfo(cl_platform_id, cl_platform_info, size_t, void *, size_t *);
static cl_int clGetEventProfilingInfo(cl_event, cl_profiling_info, size_t, void *, size_t *);
static cl_program clCreateProgramWithBinary(cl_context, cl_uint, const cl_device_id *, const size_t *, const unsigned char **, cl_int *, cl_int *);
static cl_command_queue clCreateCommandQueue(cl_context, cl_device_id, cl_command_queue_properties, cl_int *);
static cl_int clRetainEvent(cl_event);
static cl_int clReleaseProgram(cl_program);
static cl_int clFlush(cl_command_queue);
static cl_int clGetProgramInfo(cl_program, cl_program_info, size_t, void *, size_t *);
static cl_int clGetKernelInfo(cl_kernel, cl_kernel_info, size_t, void *, size_t *);
static cl_int clGetKernelWorkGroupInfo(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *);
static cl_kernel clCreateKernel(cl_program, const char *, cl_int *);
static cl_int clCreateKernelsInProgram(cl_program, cl_uint, cl_kernel*, cl_uint*);
static cl_mem clCreateBuffer(cl_context, cl_mem_flags, size_t, void *, cl_int *);
static cl_program clCreateProgramWithSource(cl_context, cl_uint, const char **, const size_t *, cl_int *);
static cl_int clReleaseKernel(cl_kernel);
// CUDA
static CUresult cuCtxGetCurrent(CUcontext *pctx);
static CUresult cuCtxSetCurrent(CUcontext ctx);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuDriverGetVersion(int *driverVersion);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
static CUresult cuInit(unsigned int Flags);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig (CUfunction hfunc, CUfunc_cache config);
// NVML
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
private:
// Libraries
static void* opencl_;
static void* cuda_;
static void* nvml_;
static void* vulkan_;
static void* spvllvm_;
static void* spvcross_;
static void* opengl_;
// OpenCL functions
static void* clBuildProgram_;
static void* clEnqueueNDRangeKernel_;
static void* clSetKernelArg_;
static void* clReleaseMemObject_;
static void* clFinish_;
static void* clGetMemObjectInfo_;
static void* clGetCommandQueueInfo_;
static void* clReleaseContext_;
static void* clReleaseEvent_;
static void* clEnqueueWriteBuffer_;
static void* clEnqueueReadBuffer_;
static void* clGetProgramBuildInfo_;
static void* clReleaseDevice_;
static void* clCreateContext_;
static void* clGetDeviceIDs_;
static void* clGetContextInfo_;
static void* clGetDeviceInfo_;
static void* clReleaseCommandQueue_;
static void* clGetPlatformIDs_;
static void* clGetPlatformInfo_;
static void* clGetEventProfilingInfo_;
static void* clCreateProgramWithBinary_;
static void* clCreateCommandQueue_;
static void* clRetainEvent_;
static void* clReleaseProgram_;
static void* clFlush_;
static void* clGetProgramInfo_;
static void* clGetKernelInfo_;
static void* clGetKernelWorkGroupInfo_;
static void* clCreateKernel_;
static void* clCreateKernelsInProgram_;
static void* clCreateBuffer_;
static void* clCreateProgramWithSource_;
static void* clReleaseKernel_;
// CUDA functions
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuEventCreate_;
static void* cuDeviceGet_;
static void* cuMemcpyDtoH_v2_;
static void* cuStreamCreate_;
static void* cuEventElapsedTime_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuDriverGetVersion_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuModuleGetGlobal_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuModuleLoad_;
static void* cuLaunchKernel_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
static void* cuMemcpyHtoD_v2_;
static void* cuInit_;
static void* cuEventRecord_;
static void* cuCtxCreate_v2_;
static void* cuModuleGetFunction_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuEventDestroy_v2_;
static void* cuMemAlloc_v2_;
static void* cuPointerGetAttribute_;
static void* cuCtxGetDevice_;
static void* cuMemsetD8Async_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
// NVML
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
// LLVM to SPIR-V
static void* initializeLLVMToSPIRVPass_;
static void* writeSpirv_;
};
}
}
#endif

208
include/triton/driver/error.h Executable file
View File

@@ -0,0 +1,208 @@
#pragma once
#ifndef _TRITON_DRIVER_ERROR_H_
#define _TRITON_DRIVER_ERROR_H_
#include <exception>
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
namespace exception
{
namespace nvrtc
{
#define ISAAC_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
ISAAC_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
ISAAC_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
ISAAC_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
ISAAC_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
ISAAC_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
#undef ISAAC_CREATE_NVRTC_EXCEPTION
}
namespace cuda
{
class base: public std::exception{};
#define ISAAC_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
ISAAC_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
ISAAC_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
ISAAC_CREATE_CUDA_EXCEPTION(no_device ,"no device");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
ISAAC_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
ISAAC_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
ISAAC_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
ISAAC_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
ISAAC_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
ISAAC_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
ISAAC_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
ISAAC_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
ISAAC_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
ISAAC_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
ISAAC_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
ISAAC_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
ISAAC_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
ISAAC_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
ISAAC_CREATE_CUDA_EXCEPTION(not_found ,"not found");
ISAAC_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
ISAAC_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
ISAAC_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
ISAAC_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
ISAAC_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
ISAAC_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
ISAAC_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
ISAAC_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
ISAAC_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
ISAAC_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
ISAAC_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
ISAAC_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
ISAAC_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
ISAAC_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
ISAAC_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
ISAAC_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
ISAAC_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
#undef ISAAC_CREATE_CUDA_EXCEPTION
}
namespace cublas
{
class base: public std::exception{};
#define ISAAC_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
ISAAC_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
ISAAC_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
ISAAC_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
ISAAC_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
ISAAC_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
ISAAC_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
ISAAC_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
#undef ISAAC_CREATE_CUBLAS_EXCEPTION
}
namespace cudnn
{
#define ISAAC_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
ISAAC_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
ISAAC_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
ISAAC_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
ISAAC_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
ISAAC_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
ISAAC_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
ISAAC_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}
namespace ocl
{
class base: public std::exception{};
#define ISAAC_CREATE_CL_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "OpenCL: Error- " msg; } }
ISAAC_CREATE_CL_EXCEPTION(device_not_found, "device not found");
ISAAC_CREATE_CL_EXCEPTION(device_not_available, "device not available");
ISAAC_CREATE_CL_EXCEPTION(compiler_not_available, "compiler not available");
ISAAC_CREATE_CL_EXCEPTION(mem_object_allocation_failure, "object allocation failure");
ISAAC_CREATE_CL_EXCEPTION(out_of_resources, "launch out of resources");
ISAAC_CREATE_CL_EXCEPTION(out_of_host_memory, "out of host memory");
ISAAC_CREATE_CL_EXCEPTION(profiling_info_not_available, "profiling info not available");
ISAAC_CREATE_CL_EXCEPTION(mem_copy_overlap, "mem copy overlap");
ISAAC_CREATE_CL_EXCEPTION(image_format_mismatch, "image format mismatch");
ISAAC_CREATE_CL_EXCEPTION(image_format_not_supported, "image format not supported");
ISAAC_CREATE_CL_EXCEPTION(build_program_failure, "build program failure");
ISAAC_CREATE_CL_EXCEPTION(map_failure, "map failure");
ISAAC_CREATE_CL_EXCEPTION(invalid_value, "invalid value");
ISAAC_CREATE_CL_EXCEPTION(invalid_device_type, "invalid device type");
ISAAC_CREATE_CL_EXCEPTION(invalid_platform, "invalid platform");
ISAAC_CREATE_CL_EXCEPTION(invalid_device, "invalid device");
ISAAC_CREATE_CL_EXCEPTION(invalid_context, "invalid context");
ISAAC_CREATE_CL_EXCEPTION(invalid_queue_properties, "invalid queue properties");
ISAAC_CREATE_CL_EXCEPTION(invalid_command_queue, "invalid command queue");
ISAAC_CREATE_CL_EXCEPTION(invalid_host_ptr, "invalid host pointer");
ISAAC_CREATE_CL_EXCEPTION(invalid_mem_object, "invalid mem object");
ISAAC_CREATE_CL_EXCEPTION(invalid_image_format_descriptor, "invalid image format descriptor");
ISAAC_CREATE_CL_EXCEPTION(invalid_image_size, "invalid image size");
ISAAC_CREATE_CL_EXCEPTION(invalid_sampler, "invalid sampler");
ISAAC_CREATE_CL_EXCEPTION(invalid_binary, "invalid binary");
ISAAC_CREATE_CL_EXCEPTION(invalid_build_options, "invalid build options");
ISAAC_CREATE_CL_EXCEPTION(invalid_program, "invalid program");
ISAAC_CREATE_CL_EXCEPTION(invalid_program_executable, "invalid program executable");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel_name, "invalid kernel name");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel_definition, "invalid kernel definition");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel, "invalid kernel");
ISAAC_CREATE_CL_EXCEPTION(invalid_arg_index, "invalid arg index");
ISAAC_CREATE_CL_EXCEPTION(invalid_arg_value, "invalid arg value");
ISAAC_CREATE_CL_EXCEPTION(invalid_arg_size, "invalid arg size");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel_args, "invalid kernel args");
ISAAC_CREATE_CL_EXCEPTION(invalid_work_dimension, "invalid work dimension");
ISAAC_CREATE_CL_EXCEPTION(invalid_work_group_size, "invalid work group size");
ISAAC_CREATE_CL_EXCEPTION(invalid_work_item_size, "invalid work item size");
ISAAC_CREATE_CL_EXCEPTION(invalid_global_offset, "invalid global offset");
ISAAC_CREATE_CL_EXCEPTION(invalid_event_wait_list, "invalid event wait list");
ISAAC_CREATE_CL_EXCEPTION(invalid_event, "invalid event");
ISAAC_CREATE_CL_EXCEPTION(invalid_operation, "invalid operation");
ISAAC_CREATE_CL_EXCEPTION(invalid_gl_object, "invalid GL object");
ISAAC_CREATE_CL_EXCEPTION(invalid_buffer_size, "invalid buffer size");
ISAAC_CREATE_CL_EXCEPTION(invalid_mip_level, "invalid MIP level");
ISAAC_CREATE_CL_EXCEPTION(invalid_global_work_size, "invalid global work size");
#ifdef CL_INVALID_PROPERTY
ISAAC_CREATE_CL_EXCEPTION(invalid_property, "invalid property");
#endif
}
}
}
}
#endif

29
include/triton/driver/event.h Executable file
View File

@@ -0,0 +1,29 @@
#pragma once
#ifndef _TRITON_DRIVER_EVENT_H_
#define _TRITON_DRIVER_EVENT_H_
#include "triton/driver/handle.h"
namespace triton
{
namespace driver
{
// event
class event
{
public:
float elapsed_time() const;
handle<cu_event_t> const & cu() const;
private:
handle<cu_event_t> cu_;
};
}
}
#endif

132
include/triton/driver/handle.h Executable file
View File

@@ -0,0 +1,132 @@
#pragma once
#ifndef _TRITON_DRIVER_HANDLE_H_
#define _TRITON_DRIVER_HANDLE_H_
#include <memory>
#include <map>
#include <iostream>
#include <functional>
#include <type_traits>
#include "triton/driver/dispatch.h"
namespace llvm
{
class ExecutionEngine;
class Function;
}
namespace triton
{
namespace driver
{
enum backend_t {
CUDA,
OpenCL,
Host
};
// Host handles
struct host_platform_t{
};
struct host_device_t{
};
struct host_context_t{
};
struct host_stream_t{
};
struct host_module_t{
std::string error;
llvm::ExecutionEngine* engine;
std::map<std::string, llvm::Function*> functions;
};
struct host_function_t{
llvm::Function* fn;
};
struct host_buffer_t{
char* data;
};
// Extra CUDA handles
struct cu_event_t{
operator bool() const { return first && second; }
CUevent first;
CUevent second;
};
struct CUPlatform{
CUPlatform() : status_(dispatch::cuInit(0)) { }
operator bool() const { return status_; }
private:
CUresult status_;
};
template<class T, class CUType>
class handle_interface{
public:
//Accessors
operator CUType() const { return *(((T*)this)->cu().h_); }
//Comparison
bool operator==(handle_interface const & y) { return (CUType)(*this) == (CUType)(y); }
bool operator!=(handle_interface const & y) { return (CUType)(*this) != (CUType)(y); }
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
};
template<class T>
class handle{
public:
template<class, class> friend class handle_interface;
public:
//Constructors
handle(T h, bool take_ownership = true);
handle();
~handle();
T& operator*() { return *h_; }
T const & operator*() const { return *h_; }
T* operator->() const { return h_.get(); }
protected:
std::shared_ptr<T> h_;
bool has_ownership_;
};
template<class CUType, class CLType, class HostType>
class polymorphic_resource {
public:
polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){}
polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership), backend_(OpenCL){}
polymorphic_resource(HostType hst, bool take_ownership): hst_(hst, take_ownership), backend_(Host){}
virtual ~polymorphic_resource() { }
handle<CUType> cu() { return cu_; }
handle<CLType> cl() { return cl_; }
handle<HostType> hst() { return hst_; }
const handle<CUType>& cu() const { return cu_; }
const handle<CLType>& cl() const { return cl_; }
const handle<HostType>& hst() const { return hst_; }
backend_t backend() { return backend_; }
protected:
handle<CLType> cl_;
handle<CUType> cu_;
handle<HostType> hst_;
backend_t backend_;
};
}
}
#endif

88
include/triton/driver/kernel.h Executable file
View File

@@ -0,0 +1,88 @@
#pragma once
#ifndef _TRITON_DRIVER_KERNEL_H_
#define _TRITON_DRIVER_KERNEL_H_
#include "triton/driver/module.h"
#include "triton/driver/handle.h"
#include <memory>
namespace llvm
{
class GenericValue;
}
namespace triton
{
namespace driver
{
class cu_buffer;
// Base
class kernel: public polymorphic_resource<CUfunction, cl_kernel, host_function_t> {
public:
kernel(driver::module* program, CUfunction fn, bool has_ownership);
kernel(driver::module* program, cl_kernel fn, bool has_ownership);
kernel(driver::module* program, host_function_t fn, bool has_ownership);
// Getters
driver::module* module();
// Factory methods
static kernel* create(driver::module* program, const char* name);
// Arguments setters
virtual void setArg(unsigned int index, std::size_t size, void* ptr) = 0;
virtual void setArg(unsigned int index, buffer *) = 0;
template<class T> void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); }
private:
driver::module* program_;
};
// Host
class host_kernel: public kernel {
public:
//Constructors
host_kernel(driver::module* program, const char* name);
// Arguments setters
void setArg(unsigned int index, std::size_t size, void* ptr);
void setArg(unsigned int index, driver::buffer* buffer);
// Params
const std::vector<void*>& params();
private:
std::vector<std::shared_ptr<void> > params_store_;
std::vector<void*> params_;
};
// OpenCL
class ocl_kernel: public kernel {
public:
//Constructors
ocl_kernel(driver::module* program, const char* name);
// Arguments setters
void setArg(unsigned int index, std::size_t size, void* ptr);
void setArg(unsigned int index, driver::buffer* buffer);
};
// CUDA
class cu_kernel: public kernel {
public:
//Constructors
cu_kernel(driver::module* program, const char * name);
// Arguments setters
void setArg(unsigned int index, std::size_t size, void* ptr);
void setArg(unsigned int index, driver::buffer* buffer);
//Arguments getters
void* const* cu_params() const;
private:
std::vector<std::shared_ptr<void> > cu_params_store_;
std::vector<void*> cu_params_;
};
}
}
#endif

87
include/triton/driver/module.h Executable file
View File

@@ -0,0 +1,87 @@
#pragma once
#ifndef _TRITON_DRIVER_MODULE_H_
#define _TRITON_DRIVER_MODULE_H_
#include <map>
#include "triton/driver/handle.h"
#include "triton/driver/context.h"
#include "triton/driver/buffer.h"
namespace llvm
{
class Module;
template<class T>
class SmallVectorImpl;
}
namespace triton
{
namespace driver
{
class cu_context;
class cu_device;
// Base
class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
protected:
void init_llvm();
enum file_type_t{
Object,
Assembly
};
public:
module(driver::context* ctx, CUmodule mod, bool has_ownership);
module(driver::context* ctx, cl_program mod, bool has_ownership);
module(driver::context* ctx, host_module_t mod, bool has_ownership);
static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
driver::context* context() const;
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer,
const std::string &features,
file_type_t file_type);
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
protected:
driver::context* ctx_;
};
// CPU
class host_module: public module{
public:
host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
std::unique_ptr<buffer> symbol(const char * name) const;
};
// OpenCL
class ocl_module: public module{
public:
ocl_module(driver::context* context, std::unique_ptr<llvm::Module> module);
std::unique_ptr<buffer> symbol(const char * name) const;
};
// CUDA
class cu_module: public module {
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
public:
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
cu_module(driver::context* context, const std::string& source);
std::unique_ptr<buffer> symbol(const char * name) const;
private:
std::string source_;
};
}
}
#endif

View File

@@ -0,0 +1,70 @@
#pragma once
#ifndef _TRITON_DRIVER_PLATFORM_H_
#define _TRITON_DRIVER_PLATFORM_H_
#include <vector>
#include <string>
#include "triton/driver/handle.h"
namespace triton
{
namespace driver
{
class device;
class platform
{
public:
// Constructor
platform(const std::string& name): name_(name){ }
// Accessors
std::string name() const { return name_; }
// Virtual methods
virtual std::string version() const = 0;
virtual void devices(std::vector<driver::device *> &devices) const = 0;
private:
std::string name_;
};
// CUDA
class cu_platform: public platform
{
public:
cu_platform(): platform("CUDA") { }
std::string version() const;
void devices(std::vector<driver::device*> &devices) const;
private:
handle<CUPlatform> cu_;
};
// OpenCL
class cl_platform: public platform
{
public:
cl_platform(cl_platform_id cl): platform("OpenCL"), cl_(cl) { }
std::string version() const;
void devices(std::vector<driver::device*> &devices) const;
private:
handle<cl_platform_id> cl_;
};
// Host
class host_platform: public platform
{
public:
host_platform(): platform("CPU") { }
std::string version() const;
void devices(std::vector<driver::device*> &devices) const;
};
}
}
#endif

93
include/triton/driver/stream.h Executable file
View File

@@ -0,0 +1,93 @@
#pragma once
#ifndef _TRITON_DRIVER_STREAM_H_
#define _TRITON_DRIVER_STREAM_H_
#include <map>
#include "triton/driver/context.h"
#include "triton/driver/device.h"
#include "triton/driver/handle.h"
#include "triton/driver/buffer.h"
namespace triton
{
namespace driver
{
class kernel;
class event;
class Range;
class cu_buffer;
// Base
class stream: public polymorphic_resource<CUstream, cl_command_queue, host_stream_t> {
public:
stream(driver::context *ctx, CUstream, bool has_ownership);
stream(driver::context *ctx, cl_command_queue, bool has_ownership);
stream(driver::context *ctx, host_stream_t, bool has_ownership);
// factory
static driver::stream* create(driver::context* ctx);
// accessors
driver::context* context() const;
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
template<class T> void write(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T> const & x)
{ write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
template<class T> void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
{ read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
protected:
driver::context *ctx_;
};
// Host
class host_stream: public stream {
public:
// Constructors
host_stream(driver::context *ctx);
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
// OpenCL
class cl_stream: public stream {
public:
// Constructors
cl_stream(driver::context *ctx);
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
// CUDA
class cu_stream: public stream {
public:
// Constructors
cu_stream(CUstream str, bool take_ownership);
cu_stream(driver::context* context);
// Overridden
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
}
}
#endif

1468
include/triton/external/CL/cl.h vendored Normal file

File diff suppressed because it is too large Load Diff

12947
include/triton/external/CL/cl.hpp vendored Normal file

File diff suppressed because it is too large Load Diff

9677
include/triton/external/CL/cl2.hpp vendored Normal file

File diff suppressed because it is too large Load Diff

131
include/triton/external/CL/cl_d3d10.h vendored Normal file
View File

@@ -0,0 +1,131 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_CL_D3D10_H
#define __OPENCL_CL_D3D10_H
#include <d3d10.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
/******************************************************************************
* cl_khr_d3d10_sharing */
#define cl_khr_d3d10_sharing 1
typedef cl_uint cl_d3d10_device_source_khr;
typedef cl_uint cl_d3d10_device_set_khr;
/******************************************************************************/
/* Error Codes */
#define CL_INVALID_D3D10_DEVICE_KHR -1002
#define CL_INVALID_D3D10_RESOURCE_KHR -1003
#define CL_D3D10_RESOURCE_ALREADY_ACQUIRED_KHR -1004
#define CL_D3D10_RESOURCE_NOT_ACQUIRED_KHR -1005
/* cl_d3d10_device_source_nv */
#define CL_D3D10_DEVICE_KHR 0x4010
#define CL_D3D10_DXGI_ADAPTER_KHR 0x4011
/* cl_d3d10_device_set_nv */
#define CL_PREFERRED_DEVICES_FOR_D3D10_KHR 0x4012
#define CL_ALL_DEVICES_FOR_D3D10_KHR 0x4013
/* cl_context_info */
#define CL_CONTEXT_D3D10_DEVICE_KHR 0x4014
#define CL_CONTEXT_D3D10_PREFER_SHARED_RESOURCES_KHR 0x402C
/* cl_mem_info */
#define CL_MEM_D3D10_RESOURCE_KHR 0x4015
/* cl_image_info */
#define CL_IMAGE_D3D10_SUBRESOURCE_KHR 0x4016
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_D3D10_OBJECTS_KHR 0x4017
#define CL_COMMAND_RELEASE_D3D10_OBJECTS_KHR 0x4018
/******************************************************************************/
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromD3D10KHR_fn)(
cl_platform_id platform,
cl_d3d10_device_source_khr d3d_device_source,
void * d3d_object,
cl_d3d10_device_set_khr d3d_device_set,
cl_uint num_entries,
cl_device_id * devices,
cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10BufferKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D10Buffer * resource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10Texture2DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D10Texture2D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10Texture3DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D10Texture3D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireD3D10ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseD3D10ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_0;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_D3D10_H */

131
include/triton/external/CL/cl_d3d11.h vendored Normal file
View File

@@ -0,0 +1,131 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_CL_D3D11_H
#define __OPENCL_CL_D3D11_H
#include <d3d11.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
/******************************************************************************
* cl_khr_d3d11_sharing */
#define cl_khr_d3d11_sharing 1
typedef cl_uint cl_d3d11_device_source_khr;
typedef cl_uint cl_d3d11_device_set_khr;
/******************************************************************************/
/* Error Codes */
#define CL_INVALID_D3D11_DEVICE_KHR -1006
#define CL_INVALID_D3D11_RESOURCE_KHR -1007
#define CL_D3D11_RESOURCE_ALREADY_ACQUIRED_KHR -1008
#define CL_D3D11_RESOURCE_NOT_ACQUIRED_KHR -1009
/* cl_d3d11_device_source */
#define CL_D3D11_DEVICE_KHR 0x4019
#define CL_D3D11_DXGI_ADAPTER_KHR 0x401A
/* cl_d3d11_device_set */
#define CL_PREFERRED_DEVICES_FOR_D3D11_KHR 0x401B
#define CL_ALL_DEVICES_FOR_D3D11_KHR 0x401C
/* cl_context_info */
#define CL_CONTEXT_D3D11_DEVICE_KHR 0x401D
#define CL_CONTEXT_D3D11_PREFER_SHARED_RESOURCES_KHR 0x402D
/* cl_mem_info */
#define CL_MEM_D3D11_RESOURCE_KHR 0x401E
/* cl_image_info */
#define CL_IMAGE_D3D11_SUBRESOURCE_KHR 0x401F
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_D3D11_OBJECTS_KHR 0x4020
#define CL_COMMAND_RELEASE_D3D11_OBJECTS_KHR 0x4021
/******************************************************************************/
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromD3D11KHR_fn)(
cl_platform_id platform,
cl_d3d11_device_source_khr d3d_device_source,
void * d3d_object,
cl_d3d11_device_set_khr d3d_device_set,
cl_uint num_entries,
cl_device_id * devices,
cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11BufferKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D11Buffer * resource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11Texture2DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D11Texture2D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11Texture3DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D11Texture3D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireD3D11ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseD3D11ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_D3D11_H */

View File

@@ -0,0 +1,132 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_H
#define __OPENCL_CL_DX9_MEDIA_SHARING_H
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
/******************************************************************************/
/* cl_khr_dx9_media_sharing */
#define cl_khr_dx9_media_sharing 1
typedef cl_uint cl_dx9_media_adapter_type_khr;
typedef cl_uint cl_dx9_media_adapter_set_khr;
#if defined(_WIN32)
#include <d3d9.h>
typedef struct _cl_dx9_surface_info_khr
{
IDirect3DSurface9 *resource;
HANDLE shared_handle;
} cl_dx9_surface_info_khr;
#endif
/******************************************************************************/
/* Error Codes */
#define CL_INVALID_DX9_MEDIA_ADAPTER_KHR -1010
#define CL_INVALID_DX9_MEDIA_SURFACE_KHR -1011
#define CL_DX9_MEDIA_SURFACE_ALREADY_ACQUIRED_KHR -1012
#define CL_DX9_MEDIA_SURFACE_NOT_ACQUIRED_KHR -1013
/* cl_media_adapter_type_khr */
#define CL_ADAPTER_D3D9_KHR 0x2020
#define CL_ADAPTER_D3D9EX_KHR 0x2021
#define CL_ADAPTER_DXVA_KHR 0x2022
/* cl_media_adapter_set_khr */
#define CL_PREFERRED_DEVICES_FOR_DX9_MEDIA_ADAPTER_KHR 0x2023
#define CL_ALL_DEVICES_FOR_DX9_MEDIA_ADAPTER_KHR 0x2024
/* cl_context_info */
#define CL_CONTEXT_ADAPTER_D3D9_KHR 0x2025
#define CL_CONTEXT_ADAPTER_D3D9EX_KHR 0x2026
#define CL_CONTEXT_ADAPTER_DXVA_KHR 0x2027
/* cl_mem_info */
#define CL_MEM_DX9_MEDIA_ADAPTER_TYPE_KHR 0x2028
#define CL_MEM_DX9_MEDIA_SURFACE_INFO_KHR 0x2029
/* cl_image_info */
#define CL_IMAGE_DX9_MEDIA_PLANE_KHR 0x202A
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_DX9_MEDIA_SURFACES_KHR 0x202B
#define CL_COMMAND_RELEASE_DX9_MEDIA_SURFACES_KHR 0x202C
/******************************************************************************/
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromDX9MediaAdapterKHR_fn)(
cl_platform_id platform,
cl_uint num_media_adapters,
cl_dx9_media_adapter_type_khr * media_adapter_type,
void * media_adapters,
cl_dx9_media_adapter_set_khr media_adapter_set,
cl_uint num_entries,
cl_device_id * devices,
cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromDX9MediaSurfaceKHR_fn)(
cl_context context,
cl_mem_flags flags,
cl_dx9_media_adapter_type_khr adapter_type,
void * surface_info,
cl_uint plane,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireDX9MediaSurfacesKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseDX9MediaSurfacesKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_DX9_MEDIA_SHARING_H */

View File

@@ -0,0 +1,182 @@
/**********************************************************************************
* Copyright (c) 2008-2016 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/*****************************************************************************\
Copyright (c) 2013-2016 Intel Corporation All Rights Reserved.
THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL OR ITS
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE
MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File Name: cl_dx9_media_sharing_intel.h
Abstract:
Notes:
\*****************************************************************************/
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H
#define __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include <d3d9.h>
#include <dxvahd.h>
#include <wtypes.h>
#include <d3d9types.h>
#ifdef __cplusplus
extern "C" {
#endif
/***************************************
* cl_intel_dx9_media_sharing extension *
****************************************/
#define cl_intel_dx9_media_sharing 1
typedef cl_uint cl_dx9_device_source_intel;
typedef cl_uint cl_dx9_device_set_intel;
/* error codes */
#define CL_INVALID_DX9_DEVICE_INTEL -1010
#define CL_INVALID_DX9_RESOURCE_INTEL -1011
#define CL_DX9_RESOURCE_ALREADY_ACQUIRED_INTEL -1012
#define CL_DX9_RESOURCE_NOT_ACQUIRED_INTEL -1013
/* cl_dx9_device_source_intel */
#define CL_D3D9_DEVICE_INTEL 0x4022
#define CL_D3D9EX_DEVICE_INTEL 0x4070
#define CL_DXVA_DEVICE_INTEL 0x4071
/* cl_dx9_device_set_intel */
#define CL_PREFERRED_DEVICES_FOR_DX9_INTEL 0x4024
#define CL_ALL_DEVICES_FOR_DX9_INTEL 0x4025
/* cl_context_info */
#define CL_CONTEXT_D3D9_DEVICE_INTEL 0x4026
#define CL_CONTEXT_D3D9EX_DEVICE_INTEL 0x4072
#define CL_CONTEXT_DXVA_DEVICE_INTEL 0x4073
/* cl_mem_info */
#define CL_MEM_DX9_RESOURCE_INTEL 0x4027
#define CL_MEM_DX9_SHARED_HANDLE_INTEL 0x4074
/* cl_image_info */
#define CL_IMAGE_DX9_PLANE_INTEL 0x4075
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_DX9_OBJECTS_INTEL 0x402A
#define CL_COMMAND_RELEASE_DX9_OBJECTS_INTEL 0x402B
/******************************************************************************/
extern CL_API_ENTRY cl_int CL_API_CALL
clGetDeviceIDsFromDX9INTEL(
cl_platform_id /* platform */,
cl_dx9_device_source_intel /* dx9_device_source */,
void* /* dx9_object */,
cl_dx9_device_set_intel /* dx9_device_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int (CL_API_CALL* clGetDeviceIDsFromDX9INTEL_fn)(
cl_platform_id /* platform */,
cl_dx9_device_source_intel /* dx9_device_source */,
void* /* dx9_object */,
cl_dx9_device_set_intel /* dx9_device_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromDX9MediaSurfaceINTEL(
cl_context /* context */,
cl_mem_flags /* flags */,
IDirect3DSurface9* /* resource */,
HANDLE /* sharedHandle */,
UINT /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromDX9MediaSurfaceINTEL_fn)(
cl_context /* context */,
cl_mem_flags /* flags */,
IDirect3DSurface9* /* resource */,
HANDLE /* sharedHandle */,
UINT /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireDX9ObjectsINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireDX9ObjectsINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseDX9ObjectsINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseDX9ObjectsINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H */

136
include/triton/external/CL/cl_egl.h vendored Normal file
View File

@@ -0,0 +1,136 @@
/*******************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
#ifndef __OPENCL_CL_EGL_H
#define __OPENCL_CL_EGL_H
#ifdef __APPLE__
#else
#include "cl.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
/* Command type for events created with clEnqueueAcquireEGLObjectsKHR */
#define CL_COMMAND_EGL_FENCE_SYNC_OBJECT_KHR 0x202F
#define CL_COMMAND_ACQUIRE_EGL_OBJECTS_KHR 0x202D
#define CL_COMMAND_RELEASE_EGL_OBJECTS_KHR 0x202E
/* Error type for clCreateFromEGLImageKHR */
#define CL_INVALID_EGL_OBJECT_KHR -1093
#define CL_EGL_RESOURCE_NOT_ACQUIRED_KHR -1092
/* CLeglImageKHR is an opaque handle to an EGLImage */
typedef void* CLeglImageKHR;
/* CLeglDisplayKHR is an opaque handle to an EGLDisplay */
typedef void* CLeglDisplayKHR;
/* CLeglSyncKHR is an opaque handle to an EGLSync object */
typedef void* CLeglSyncKHR;
/* properties passed to clCreateFromEGLImageKHR */
typedef intptr_t cl_egl_image_properties_khr;
#define cl_khr_egl_image 1
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromEGLImageKHR(cl_context /* context */,
CLeglDisplayKHR /* egldisplay */,
CLeglImageKHR /* eglimage */,
cl_mem_flags /* flags */,
const cl_egl_image_properties_khr * /* properties */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromEGLImageKHR_fn)(
cl_context context,
CLeglDisplayKHR egldisplay,
CLeglImageKHR eglimage,
cl_mem_flags flags,
const cl_egl_image_properties_khr * properties,
cl_int * errcode_ret);
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireEGLObjectsKHR(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireEGLObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event);
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseEGLObjectsKHR(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseEGLObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event);
#define cl_khr_egl_event 1
extern CL_API_ENTRY cl_event CL_API_CALL
clCreateEventFromEGLSyncKHR(cl_context /* context */,
CLeglSyncKHR /* sync */,
CLeglDisplayKHR /* display */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_event (CL_API_CALL *clCreateEventFromEGLSyncKHR_fn)(
cl_context context,
CLeglSyncKHR sync,
CLeglDisplayKHR display,
cl_int * errcode_ret);
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_EGL_H */

670
include/triton/external/CL/cl_ext.h vendored Normal file
View File

@@ -0,0 +1,670 @@
/*******************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
/* $Revision: 11928 $ on $Date: 2010-07-13 09:04:56 -0700 (Tue, 13 Jul 2010) $ */
/* cl_ext.h contains OpenCL extensions which don't have external */
/* (OpenGL, D3D) dependencies. */
#ifndef __CL_EXT_H
#define __CL_EXT_H
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __APPLE__
#include <OpenCL/cl.h>
#include <AvailabilityMacros.h>
#else
#include "cl.h"
#endif
/* cl_khr_fp64 extension - no extension #define since it has no functions */
#define CL_DEVICE_DOUBLE_FP_CONFIG 0x1032
/* cl_khr_fp16 extension - no extension #define since it has no functions */
#define CL_DEVICE_HALF_FP_CONFIG 0x1033
/* Memory object destruction
*
* Apple extension for use to manage externally allocated buffers used with cl_mem objects with CL_MEM_USE_HOST_PTR
*
* Registers a user callback function that will be called when the memory object is deleted and its resources
* freed. Each call to clSetMemObjectCallbackFn registers the specified user callback function on a callback
* stack associated with memobj. The registered user callback functions are called in the reverse order in
* which they were registered. The user callback functions are called and then the memory object is deleted
* and its resources freed. This provides a mechanism for the application (and libraries) using memobj to be
* notified when the memory referenced by host_ptr, specified when the memory object is created and used as
* the storage bits for the memory object, can be reused or freed.
*
* The application may not call CL api's with the cl_mem object passed to the pfn_notify.
*
* Please check for the "cl_APPLE_SetMemObjectDestructor" extension using clGetDeviceInfo(CL_DEVICE_EXTENSIONS)
* before using.
*/
#define cl_APPLE_SetMemObjectDestructor 1
cl_int CL_API_ENTRY clSetMemObjectDestructorAPPLE( cl_mem /* memobj */,
void (* /*pfn_notify*/)( cl_mem /* memobj */, void* /*user_data*/),
void * /*user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/* Context Logging Functions
*
* The next three convenience functions are intended to be used as the pfn_notify parameter to clCreateContext().
* Please check for the "cl_APPLE_ContextLoggingFunctions" extension using clGetDeviceInfo(CL_DEVICE_EXTENSIONS)
* before using.
*
* clLogMessagesToSystemLog fowards on all log messages to the Apple System Logger
*/
#define cl_APPLE_ContextLoggingFunctions 1
extern void CL_API_ENTRY clLogMessagesToSystemLogAPPLE( const char * /* errstr */,
const void * /* private_info */,
size_t /* cb */,
void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/* clLogMessagesToStdout sends all log messages to the file descriptor stdout */
extern void CL_API_ENTRY clLogMessagesToStdoutAPPLE( const char * /* errstr */,
const void * /* private_info */,
size_t /* cb */,
void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/* clLogMessagesToStderr sends all log messages to the file descriptor stderr */
extern void CL_API_ENTRY clLogMessagesToStderrAPPLE( const char * /* errstr */,
const void * /* private_info */,
size_t /* cb */,
void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/************************
* cl_khr_icd extension *
************************/
#define cl_khr_icd 1
/* cl_platform_info */
#define CL_PLATFORM_ICD_SUFFIX_KHR 0x0920
/* Additional Error Codes */
#define CL_PLATFORM_NOT_FOUND_KHR -1001
extern CL_API_ENTRY cl_int CL_API_CALL
clIcdGetPlatformIDsKHR(cl_uint /* num_entries */,
cl_platform_id * /* platforms */,
cl_uint * /* num_platforms */);
typedef CL_API_ENTRY cl_int (CL_API_CALL *clIcdGetPlatformIDsKHR_fn)(
cl_uint /* num_entries */,
cl_platform_id * /* platforms */,
cl_uint * /* num_platforms */);
/* Extension: cl_khr_image2D_buffer
*
* This extension allows a 2D image to be created from a cl_mem buffer without a copy.
* The type associated with a 2D image created from a buffer in an OpenCL program is image2d_t.
* Both the sampler and sampler-less read_image built-in functions are supported for 2D images
* and 2D images created from a buffer. Similarly, the write_image built-ins are also supported
* for 2D images created from a buffer.
*
* When the 2D image from buffer is created, the client must specify the width,
* height, image format (i.e. channel order and channel data type) and optionally the row pitch
*
* The pitch specified must be a multiple of CL_DEVICE_IMAGE_PITCH_ALIGNMENT pixels.
* The base address of the buffer must be aligned to CL_DEVICE_IMAGE_BASE_ADDRESS_ALIGNMENT pixels.
*/
/*************************************
* cl_khr_initalize_memory extension *
*************************************/
#define CL_CONTEXT_MEMORY_INITIALIZE_KHR 0x2030
/**************************************
* cl_khr_terminate_context extension *
**************************************/
#define CL_DEVICE_TERMINATE_CAPABILITY_KHR 0x2031
#define CL_CONTEXT_TERMINATE_KHR 0x2032
#define cl_khr_terminate_context 1
extern CL_API_ENTRY cl_int CL_API_CALL clTerminateContextKHR(cl_context /* context */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clTerminateContextKHR_fn)(cl_context /* context */) CL_EXT_SUFFIX__VERSION_1_2;
/*
* Extension: cl_khr_spir
*
* This extension adds support to create an OpenCL program object from a
* Standard Portable Intermediate Representation (SPIR) instance
*/
#define CL_DEVICE_SPIR_VERSIONS 0x40E0
#define CL_PROGRAM_BINARY_TYPE_INTERMEDIATE 0x40E1
/*****************************************
* cl_khr_create_command_queue extension *
*****************************************/
#define cl_khr_create_command_queue 1
typedef cl_bitfield cl_queue_properties_khr;
extern CL_API_ENTRY cl_command_queue CL_API_CALL
clCreateCommandQueueWithPropertiesKHR( cl_context /* context */,
cl_device_id /* device */,
const cl_queue_properties_khr* /* properties */,
cl_int* /* errcode_ret */ ) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_command_queue
(CL_API_CALL *clCreateCommandQueueWithPropertiesKHR_fn)( cl_context /* context */,
cl_device_id /* device */,
const cl_queue_properties_khr* /* properties */,
cl_int* /* errcode_ret */ ) CL_EXT_SUFFIX__VERSION_1_2;
/******************************************
* cl_nv_device_attribute_query extension *
******************************************/
/* cl_nv_device_attribute_query extension - no extension #define since it has no functions */
#define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#define CL_DEVICE_REGISTERS_PER_BLOCK_NV 0x4002
#define CL_DEVICE_WARP_SIZE_NV 0x4003
#define CL_DEVICE_GPU_OVERLAP_NV 0x4004
#define CL_DEVICE_KERNEL_EXEC_TIMEOUT_NV 0x4005
#define CL_DEVICE_INTEGRATED_MEMORY_NV 0x4006
/*********************************
* cl_amd_device_memory_flags *
*********************************/
#define cl_amd_device_memory_flags 1
#define CL_MEM_USE_PERSISTENT_MEM_AMD (1 << 6) // Alloc from GPU's CPU visible heap
/* cl_device_info */
#define CL_DEVICE_MAX_ATOMIC_COUNTERS_EXT 0x4032
/*********************************
* cl_amd_device_attribute_query *
*********************************/
#define CL_DEVICE_PROFILING_TIMER_OFFSET_AMD 0x4036
#define CL_DEVICE_TOPOLOGY_AMD 0x4037
#define CL_DEVICE_BOARD_NAME_AMD 0x4038
#define CL_DEVICE_GLOBAL_FREE_MEMORY_AMD 0x4039
#define CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD 0x4040
#define CL_DEVICE_SIMD_WIDTH_AMD 0x4041
#define CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD 0x4042
#define CL_DEVICE_WAVEFRONT_WIDTH_AMD 0x4043
#define CL_DEVICE_GLOBAL_MEM_CHANNELS_AMD 0x4044
#define CL_DEVICE_GLOBAL_MEM_CHANNEL_BANKS_AMD 0x4045
#define CL_DEVICE_GLOBAL_MEM_CHANNEL_BANK_WIDTH_AMD 0x4046
#define CL_DEVICE_LOCAL_MEM_SIZE_PER_COMPUTE_UNIT_AMD 0x4047
#define CL_DEVICE_LOCAL_MEM_BANKS_AMD 0x4048
typedef union
{
struct { cl_uint type; cl_uint data[5]; } raw;
struct { cl_uint type; cl_char unused[17]; cl_char bus; cl_char device; cl_char function; } pcie;
} cl_device_topology_amd;
#define CL_DEVICE_TOPOLOGY_TYPE_PCIE_AMD 1
/**************************
* cl_amd_offline_devices *
**************************/
#define CL_CONTEXT_OFFLINE_DEVICES_AMD 0x403F
/*********************************
* cl_arm_printf extension
*********************************/
#define CL_PRINTF_CALLBACK_ARM 0x40B0
#define CL_PRINTF_BUFFERSIZE_ARM 0x40B1
#ifdef CL_VERSION_1_1
/***********************************
* cl_ext_device_fission extension *
***********************************/
#define cl_ext_device_fission 1
extern CL_API_ENTRY cl_int CL_API_CALL
clReleaseDeviceEXT( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int
(CL_API_CALL *clReleaseDeviceEXT_fn)( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_int CL_API_CALL
clRetainDeviceEXT( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int
(CL_API_CALL *clRetainDeviceEXT_fn)( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef cl_ulong cl_device_partition_property_ext;
extern CL_API_ENTRY cl_int CL_API_CALL
clCreateSubDevicesEXT( cl_device_id /*in_device*/,
const cl_device_partition_property_ext * /* properties */,
cl_uint /*num_entries*/,
cl_device_id * /*out_devices*/,
cl_uint * /*num_devices*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int
( CL_API_CALL * clCreateSubDevicesEXT_fn)( cl_device_id /*in_device*/,
const cl_device_partition_property_ext * /* properties */,
cl_uint /*num_entries*/,
cl_device_id * /*out_devices*/,
cl_uint * /*num_devices*/ ) CL_EXT_SUFFIX__VERSION_1_1;
/* cl_device_partition_property_ext */
#define CL_DEVICE_PARTITION_EQUALLY_EXT 0x4050
#define CL_DEVICE_PARTITION_BY_COUNTS_EXT 0x4051
#define CL_DEVICE_PARTITION_BY_NAMES_EXT 0x4052
#define CL_DEVICE_PARTITION_BY_AFFINITY_DOMAIN_EXT 0x4053
/* clDeviceGetInfo selectors */
#define CL_DEVICE_PARENT_DEVICE_EXT 0x4054
#define CL_DEVICE_PARTITION_TYPES_EXT 0x4055
#define CL_DEVICE_AFFINITY_DOMAINS_EXT 0x4056
#define CL_DEVICE_REFERENCE_COUNT_EXT 0x4057
#define CL_DEVICE_PARTITION_STYLE_EXT 0x4058
/* error codes */
#define CL_DEVICE_PARTITION_FAILED_EXT -1057
#define CL_INVALID_PARTITION_COUNT_EXT -1058
#define CL_INVALID_PARTITION_NAME_EXT -1059
/* CL_AFFINITY_DOMAINs */
#define CL_AFFINITY_DOMAIN_L1_CACHE_EXT 0x1
#define CL_AFFINITY_DOMAIN_L2_CACHE_EXT 0x2
#define CL_AFFINITY_DOMAIN_L3_CACHE_EXT 0x3
#define CL_AFFINITY_DOMAIN_L4_CACHE_EXT 0x4
#define CL_AFFINITY_DOMAIN_NUMA_EXT 0x10
#define CL_AFFINITY_DOMAIN_NEXT_FISSIONABLE_EXT 0x100
/* cl_device_partition_property_ext list terminators */
#define CL_PROPERTIES_LIST_END_EXT ((cl_device_partition_property_ext) 0)
#define CL_PARTITION_BY_COUNTS_LIST_END_EXT ((cl_device_partition_property_ext) 0)
#define CL_PARTITION_BY_NAMES_LIST_END_EXT ((cl_device_partition_property_ext) 0 - 1)
/* cl_ext_atomic_counters_32 and cl_ext_atomic_counters_64 extensions
* no extension #define since they have no functions
*/
#define CL_DEVICE_MAX_ATOMIC_COUNTERS_EXT 0x4032
/*********************************
* cl_qcom_ext_host_ptr extension
*********************************/
#define CL_MEM_EXT_HOST_PTR_QCOM (1 << 29)
#define CL_DEVICE_EXT_MEM_PADDING_IN_BYTES_QCOM 0x40A0
#define CL_DEVICE_PAGE_SIZE_QCOM 0x40A1
#define CL_IMAGE_ROW_ALIGNMENT_QCOM 0x40A2
#define CL_IMAGE_SLICE_ALIGNMENT_QCOM 0x40A3
#define CL_MEM_HOST_UNCACHED_QCOM 0x40A4
#define CL_MEM_HOST_WRITEBACK_QCOM 0x40A5
#define CL_MEM_HOST_WRITETHROUGH_QCOM 0x40A6
#define CL_MEM_HOST_WRITE_COMBINING_QCOM 0x40A7
typedef cl_uint cl_image_pitch_info_qcom;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetDeviceImageInfoQCOM(cl_device_id device,
size_t image_width,
size_t image_height,
const cl_image_format *image_format,
cl_image_pitch_info_qcom param_name,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret);
typedef struct _cl_mem_ext_host_ptr
{
/* Type of external memory allocation. */
/* Legal values will be defined in layered extensions. */
cl_uint allocation_type;
/* Host cache policy for this external memory allocation. */
cl_uint host_cache_policy;
} cl_mem_ext_host_ptr;
/*********************************
* cl_qcom_ion_host_ptr extension
*********************************/
#define CL_MEM_ION_HOST_PTR_QCOM 0x40A8
typedef struct _cl_mem_ion_host_ptr
{
/* Type of external memory allocation. */
/* Must be CL_MEM_ION_HOST_PTR_QCOM for ION allocations. */
cl_mem_ext_host_ptr ext_host_ptr;
/* ION file descriptor */
int ion_filedesc;
/* Host pointer to the ION allocated memory */
void* ion_hostptr;
} cl_mem_ion_host_ptr;
#endif /* CL_VERSION_1_1 */
#if defined(CL_VERSION_1_2)
/******************************************
* cl_img_yuv_image extension *
******************************************/
/* Image formats used in clCreateImage */
#define CL_NV21_IMG 0x40D0
#define CL_YV12_IMG 0x40D1
/******************************************
* cl_img_cached_allocations extension *
******************************************/
/* Flag values used by clCreteBuffer */
#define CL_MEM_USE_UNCACHED_CPU_MEMORY_IMG (1 << 26)
#define CL_MEM_USE_CACHED_CPU_MEMORY_IMG (1 << 27)
/******************************************
* cl_img_use_gralloc_ptr extension *
******************************************/
/* Flag values used by clCreteBuffer */
#define CL_MEM_USE_GRALLOC_PTR_IMG (1 << 28)
/* To be used by clGetEventInfo: */
#define CL_COMMAND_ACQUIRE_GRALLOC_OBJECTS_IMG 0x40D2
#define CL_COMMAND_RELEASE_GRALLOC_OBJECTS_IMG 0x40D3
/* Error code from clEnqueueReleaseGrallocObjectsIMG */
#define CL_GRALLOC_RESOURCE_NOT_ACQUIRED_IMG 0x40D4
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireGrallocObjectsIMG(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseGrallocObjectsIMG(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
#endif /* CL_VERSION_1_2 */
#ifdef CL_VERSION_2_0
/*********************************
* cl_khr_subgroups extension
*********************************/
#define cl_khr_subgroups 1
/* cl_kernel_sub_group_info is declared in CL.h. */
/* cl_kernel_sub_group_info */
#define CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR 0x2033
#define CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR 0x2034
extern CL_API_ENTRY cl_int CL_API_CALL
clGetKernelSubGroupInfoKHR(cl_kernel /* in_kernel */,
cl_device_id /*in_device*/,
cl_kernel_sub_group_info /* param_name */,
size_t /*input_value_size*/,
const void * /*input_value*/,
size_t /*param_value_size*/,
void* /*param_value*/,
size_t* /*param_value_size_ret*/ ) CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED;
typedef CL_API_ENTRY cl_int
( CL_API_CALL * clGetKernelSubGroupInfoKHR_fn)(cl_kernel /* in_kernel */,
cl_device_id /*in_device*/,
cl_kernel_sub_group_info /* param_name */,
size_t /*input_value_size*/,
const void * /*input_value*/,
size_t /*param_value_size*/,
void* /*param_value*/,
size_t* /*param_value_size_ret*/ ) CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED;
#endif /* CL_VERSION_2_0 */
#ifdef CL_VERSION_2_1
/*********************************
* cl_khr_priority_hints extension
*********************************/
#define cl_khr_priority_hints 1
typedef cl_uint cl_queue_priority_khr;
/* cl_command_queue_properties */
#define CL_QUEUE_PRIORITY_KHR 0x1096
/* cl_queue_priority_khr */
#define CL_QUEUE_PRIORITY_HIGH_KHR (1<<0)
#define CL_QUEUE_PRIORITY_MED_KHR (1<<1)
#define CL_QUEUE_PRIORITY_LOW_KHR (1<<2)
#endif /* CL_VERSION_2_1 */
#ifdef CL_VERSION_2_1
/*********************************
* cl_khr_throttle_hints extension
*********************************/
#define cl_khr_throttle_hints 1
typedef cl_uint cl_queue_throttle_khr;
/* cl_command_queue_properties */
#define CL_QUEUE_THROTTLE_KHR 0x1097
/* cl_queue_throttle_khr */
#define CL_QUEUE_THROTTLE_HIGH_KHR (1<<0)
#define CL_QUEUE_THROTTLE_MED_KHR (1<<1)
#define CL_QUEUE_THROTTLE_LOW_KHR (1<<2)
#endif /* CL_VERSION_2_1 */
#ifdef CL_VERSION_2_2
/*********************************
* cl_khr_subgroup_named_barrier
*********************************/
#define cl_khr_subgroup_named_barrier 1
/* cl_device_info */
#define CL_DEVICE_MAX_NAMED_BARRIER_COUNT_KHR 0x2035
#endif /* CL_VERSION_2_2 */
/**********************************
* cl_arm_import_memory extension *
**********************************/
#ifdef CL_VERSION_1_0
typedef intptr_t cl_import_properties_arm;
/* Default and valid proporties name for cl_arm_import_memory */
#define CL_IMPORT_TYPE_ARM 0x40B2
/* Host process memory type default value for CL_IMPORT_TYPE_ARM property */
#define CL_IMPORT_TYPE_HOST_ARM 0x40B3
/* DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */
#define CL_IMPORT_TYPE_DMA_BUF_ARM 0x40B4
/* Secure DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */
#define CL_IMPORT_TYPE_SECURE_ARM 0x40B5
/* This extension adds a new function that allows for direct memory import into
* OpenCL via the clImportMemoryARM function.
*
* Memory imported through this interface will be mapped into the device's page
* tables directly, providing zero copy access. It will never fall back to copy
* operations and aliased buffers.
*
* Types of memory supported for import are specified as additional extension
* strings.
*
* This extension produces cl_mem allocations which are compatible with all other
* users of cl_mem in the standard API.
*
* This extension maps pages with the same properties as the normal buffer creation
* function clCreateBuffer.
*/
extern CL_API_ENTRY cl_mem CL_API_CALL
clImportMemoryARM( cl_context context,
cl_mem_flags flags,
const cl_import_properties_arm *properties,
void *memory,
size_t size,
cl_int *errcode_ret) CL_EXT_SUFFIX__VERSION_1_0;
#endif /* CL_VERSION_1_0 */
/******************************************
* cl_arm_shared_virtual_memory extension *
******************************************/
#ifdef CL_VERSION_1_2
/* Used by clGetDeviceInfo */
#define CL_DEVICE_SVM_CAPABILITIES_ARM 0x40B6
/* Used by clGetMemObjectInfo */
#define CL_MEM_USES_SVM_POINTER_ARM 0x40B7
/* Used by clSetKernelExecInfoARM: */
#define CL_KERNEL_EXEC_INFO_SVM_PTRS_ARM 0x40B8
#define CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM_ARM 0x40B9
/* To be used by clGetEventInfo: */
#define CL_COMMAND_SVM_FREE_ARM 0x40BA
#define CL_COMMAND_SVM_MEMCPY_ARM 0x40BB
#define CL_COMMAND_SVM_MEMFILL_ARM 0x40BC
#define CL_COMMAND_SVM_MAP_ARM 0x40BD
#define CL_COMMAND_SVM_UNMAP_ARM 0x40BE
/* Flag values returned by clGetDeviceInfo with CL_DEVICE_SVM_CAPABILITIES_ARM as the param_name. */
#define CL_DEVICE_SVM_COARSE_GRAIN_BUFFER_ARM (1 << 0)
#define CL_DEVICE_SVM_FINE_GRAIN_BUFFER_ARM (1 << 1)
#define CL_DEVICE_SVM_FINE_GRAIN_SYSTEM_ARM (1 << 2)
#define CL_DEVICE_SVM_ATOMICS_ARM (1 << 3)
/* Flag values used by clSVMAllocARM: */
#define CL_MEM_SVM_FINE_GRAIN_BUFFER_ARM (1 << 10)
#define CL_MEM_SVM_ATOMICS_ARM (1 << 11)
typedef cl_bitfield cl_svm_mem_flags_arm;
typedef cl_uint cl_kernel_exec_info_arm;
typedef cl_bitfield cl_device_svm_capabilities_arm;
extern CL_API_ENTRY void * CL_API_CALL
clSVMAllocARM(cl_context /* context */,
cl_svm_mem_flags_arm /* flags */,
size_t /* size */,
cl_uint /* alignment */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY void CL_API_CALL
clSVMFreeARM(cl_context /* context */,
void * /* svm_pointer */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMFreeARM(cl_command_queue /* command_queue */,
cl_uint /* num_svm_pointers */,
void *[] /* svm_pointers[] */,
void (CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */,
cl_uint /* num_svm_pointers */,
void *[] /* svm_pointers[] */,
void * /* user_data */),
void * /* user_data */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMMemcpyARM(cl_command_queue /* command_queue */,
cl_bool /* blocking_copy */,
void * /* dst_ptr */,
const void * /* src_ptr */,
size_t /* size */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMMemFillARM(cl_command_queue /* command_queue */,
void * /* svm_ptr */,
const void * /* pattern */,
size_t /* pattern_size */,
size_t /* size */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMMapARM(cl_command_queue /* command_queue */,
cl_bool /* blocking_map */,
cl_map_flags /* flags */,
void * /* svm_ptr */,
size_t /* size */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMUnmapARM(cl_command_queue /* command_queue */,
void * /* svm_ptr */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clSetKernelArgSVMPointerARM(cl_kernel /* kernel */,
cl_uint /* arg_index */,
const void * /* arg_value */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clSetKernelExecInfoARM(cl_kernel /* kernel */,
cl_kernel_exec_info_arm /* param_name */,
size_t /* param_value_size */,
const void * /* param_value */) CL_EXT_SUFFIX__VERSION_1_2;
#endif /* CL_VERSION_1_2 */
#ifdef __cplusplus
}
#endif
#endif /* __CL_EXT_H */

View File

@@ -0,0 +1,429 @@
/*******************************************************************************
* Copyright (c) 2008-2017 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
/*****************************************************************************\
Copyright (c) 2013-2017 Intel Corporation All Rights Reserved.
THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL OR ITS
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE
MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File Name: cl_ext_intel.h
Abstract:
Notes:
\*****************************************************************************/
#ifndef __CL_EXT_INTEL_H
#define __CL_EXT_INTEL_H
#ifdef __APPLE__
#include <OpenCL/cl.h>
#include <OpenCL/cl_platform.h>
#else
#include "cl.h"
#include "cl_platform.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
/***************************************
* cl_intel_thread_local_exec extension *
****************************************/
#define cl_intel_thread_local_exec 1
#define CL_QUEUE_THREAD_LOCAL_EXEC_ENABLE_INTEL (((cl_bitfield)1) << 31)
/***********************************************
* cl_intel_device_partition_by_names extension *
************************************************/
#define cl_intel_device_partition_by_names 1
#define CL_DEVICE_PARTITION_BY_NAMES_INTEL 0x4052
#define CL_PARTITION_BY_NAMES_LIST_END_INTEL -1
/************************************************
* cl_intel_accelerator extension *
* cl_intel_motion_estimation extension *
* cl_intel_advanced_motion_estimation extension *
*************************************************/
#define cl_intel_accelerator 1
#define cl_intel_motion_estimation 1
#define cl_intel_advanced_motion_estimation 1
typedef struct _cl_accelerator_intel* cl_accelerator_intel;
typedef cl_uint cl_accelerator_type_intel;
typedef cl_uint cl_accelerator_info_intel;
typedef struct _cl_motion_estimation_desc_intel {
cl_uint mb_block_type;
cl_uint subpixel_mode;
cl_uint sad_adjust_mode;
cl_uint search_path_type;
} cl_motion_estimation_desc_intel;
/* error codes */
#define CL_INVALID_ACCELERATOR_INTEL -1094
#define CL_INVALID_ACCELERATOR_TYPE_INTEL -1095
#define CL_INVALID_ACCELERATOR_DESCRIPTOR_INTEL -1096
#define CL_ACCELERATOR_TYPE_NOT_SUPPORTED_INTEL -1097
/* cl_accelerator_type_intel */
#define CL_ACCELERATOR_TYPE_MOTION_ESTIMATION_INTEL 0x0
/* cl_accelerator_info_intel */
#define CL_ACCELERATOR_DESCRIPTOR_INTEL 0x4090
#define CL_ACCELERATOR_REFERENCE_COUNT_INTEL 0x4091
#define CL_ACCELERATOR_CONTEXT_INTEL 0x4092
#define CL_ACCELERATOR_TYPE_INTEL 0x4093
/* cl_motion_detect_desc_intel flags */
#define CL_ME_MB_TYPE_16x16_INTEL 0x0
#define CL_ME_MB_TYPE_8x8_INTEL 0x1
#define CL_ME_MB_TYPE_4x4_INTEL 0x2
#define CL_ME_SUBPIXEL_MODE_INTEGER_INTEL 0x0
#define CL_ME_SUBPIXEL_MODE_HPEL_INTEL 0x1
#define CL_ME_SUBPIXEL_MODE_QPEL_INTEL 0x2
#define CL_ME_SAD_ADJUST_MODE_NONE_INTEL 0x0
#define CL_ME_SAD_ADJUST_MODE_HAAR_INTEL 0x1
#define CL_ME_SEARCH_PATH_RADIUS_2_2_INTEL 0x0
#define CL_ME_SEARCH_PATH_RADIUS_4_4_INTEL 0x1
#define CL_ME_SEARCH_PATH_RADIUS_16_12_INTEL 0x5
#define CL_ME_SKIP_BLOCK_TYPE_16x16_INTEL 0x0
#define CL_ME_CHROMA_INTRA_PREDICT_ENABLED_INTEL 0x1
#define CL_ME_LUMA_INTRA_PREDICT_ENABLED_INTEL 0x2
#define CL_ME_SKIP_BLOCK_TYPE_8x8_INTEL 0x4
#define CL_ME_FORWARD_INPUT_MODE_INTEL 0x1
#define CL_ME_BACKWARD_INPUT_MODE_INTEL 0x2
#define CL_ME_BIDIRECTION_INPUT_MODE_INTEL 0x3
#define CL_ME_BIDIR_WEIGHT_QUARTER_INTEL 16
#define CL_ME_BIDIR_WEIGHT_THIRD_INTEL 21
#define CL_ME_BIDIR_WEIGHT_HALF_INTEL 32
#define CL_ME_BIDIR_WEIGHT_TWO_THIRD_INTEL 43
#define CL_ME_BIDIR_WEIGHT_THREE_QUARTER_INTEL 48
#define CL_ME_COST_PENALTY_NONE_INTEL 0x0
#define CL_ME_COST_PENALTY_LOW_INTEL 0x1
#define CL_ME_COST_PENALTY_NORMAL_INTEL 0x2
#define CL_ME_COST_PENALTY_HIGH_INTEL 0x3
#define CL_ME_COST_PRECISION_QPEL_INTEL 0x0
#define CL_ME_COST_PRECISION_HPEL_INTEL 0x1
#define CL_ME_COST_PRECISION_PEL_INTEL 0x2
#define CL_ME_COST_PRECISION_DPEL_INTEL 0x3
#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_INTEL 0x0
#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_ME_LUMA_PREDICTOR_MODE_DC_INTEL 0x2
#define CL_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_LEFT_INTEL 0x3
#define CL_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_RIGHT_INTEL 0x4
#define CL_ME_LUMA_PREDICTOR_MODE_PLANE_INTEL 0x4
#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_RIGHT_INTEL 0x5
#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_DOWN_INTEL 0x6
#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_LEFT_INTEL 0x7
#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_UP_INTEL 0x8
#define CL_ME_CHROMA_PREDICTOR_MODE_DC_INTEL 0x0
#define CL_ME_CHROMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_ME_CHROMA_PREDICTOR_MODE_VERTICAL_INTEL 0x2
#define CL_ME_CHROMA_PREDICTOR_MODE_PLANE_INTEL 0x3
/* cl_device_info */
#define CL_DEVICE_ME_VERSION_INTEL 0x407E
#define CL_ME_VERSION_LEGACY_INTEL 0x0
#define CL_ME_VERSION_ADVANCED_VER_1_INTEL 0x1
#define CL_ME_VERSION_ADVANCED_VER_2_INTEL 0x2
extern CL_API_ENTRY cl_accelerator_intel CL_API_CALL
clCreateAcceleratorINTEL(
cl_context /* context */,
cl_accelerator_type_intel /* accelerator_type */,
size_t /* descriptor_size */,
const void* /* descriptor */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_accelerator_intel (CL_API_CALL *clCreateAcceleratorINTEL_fn)(
cl_context /* context */,
cl_accelerator_type_intel /* accelerator_type */,
size_t /* descriptor_size */,
const void* /* descriptor */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetAcceleratorInfoINTEL(
cl_accelerator_intel /* accelerator */,
cl_accelerator_info_intel /* param_name */,
size_t /* param_value_size */,
void* /* param_value */,
size_t* /* param_value_size_ret */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetAcceleratorInfoINTEL_fn)(
cl_accelerator_intel /* accelerator */,
cl_accelerator_info_intel /* param_name */,
size_t /* param_value_size */,
void* /* param_value */,
size_t* /* param_value_size_ret */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clRetainAcceleratorINTEL(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clRetainAcceleratorINTEL_fn)(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clReleaseAcceleratorINTEL(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clReleaseAcceleratorINTEL_fn)(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
/******************************************
* cl_intel_simultaneous_sharing extension *
*******************************************/
#define cl_intel_simultaneous_sharing 1
#define CL_DEVICE_SIMULTANEOUS_INTEROPS_INTEL 0x4104
#define CL_DEVICE_NUM_SIMULTANEOUS_INTEROPS_INTEL 0x4105
/***********************************
* cl_intel_egl_image_yuv extension *
************************************/
#define cl_intel_egl_image_yuv 1
#define CL_EGL_YUV_PLANE_INTEL 0x4107
/********************************
* cl_intel_packed_yuv extension *
*********************************/
#define cl_intel_packed_yuv 1
#define CL_YUYV_INTEL 0x4076
#define CL_UYVY_INTEL 0x4077
#define CL_YVYU_INTEL 0x4078
#define CL_VYUY_INTEL 0x4079
/********************************************
* cl_intel_required_subgroup_size extension *
*********************************************/
#define cl_intel_required_subgroup_size 1
#define CL_DEVICE_SUB_GROUP_SIZES_INTEL 0x4108
#define CL_KERNEL_SPILL_MEM_SIZE_INTEL 0x4109
#define CL_KERNEL_COMPILE_SUB_GROUP_SIZE_INTEL 0x410A
/****************************************
* cl_intel_driver_diagnostics extension *
*****************************************/
#define cl_intel_driver_diagnostics 1
typedef cl_uint cl_diagnostics_verbose_level;
#define CL_CONTEXT_SHOW_DIAGNOSTICS_INTEL 0x4106
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_ALL_INTEL ( 0xff )
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_GOOD_INTEL ( 1 )
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_BAD_INTEL ( 1 << 1 )
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_NEUTRAL_INTEL ( 1 << 2 )
/********************************
* cl_intel_planar_yuv extension *
*********************************/
#define CL_NV12_INTEL 0x410E
#define CL_MEM_NO_ACCESS_INTEL ( 1 << 24 )
#define CL_MEM_ACCESS_FLAGS_UNRESTRICTED_INTEL ( 1 << 25 )
#define CL_DEVICE_PLANAR_YUV_MAX_WIDTH_INTEL 0x417E
#define CL_DEVICE_PLANAR_YUV_MAX_HEIGHT_INTEL 0x417F
/*******************************************************
* cl_intel_device_side_avc_motion_estimation extension *
********************************************************/
#define CL_DEVICE_AVC_ME_VERSION_INTEL 0x410B
#define CL_DEVICE_AVC_ME_SUPPORTS_TEXTURE_SAMPLER_USE_INTEL 0x410C
#define CL_DEVICE_AVC_ME_SUPPORTS_PREEMPTION_INTEL 0x410D
#define CL_AVC_ME_VERSION_0_INTEL 0x0; // No support.
#define CL_AVC_ME_VERSION_1_INTEL 0x1; // First supported version.
#define CL_AVC_ME_MAJOR_16x16_INTEL 0x0
#define CL_AVC_ME_MAJOR_16x8_INTEL 0x1
#define CL_AVC_ME_MAJOR_8x16_INTEL 0x2
#define CL_AVC_ME_MAJOR_8x8_INTEL 0x3
#define CL_AVC_ME_MINOR_8x8_INTEL 0x0
#define CL_AVC_ME_MINOR_8x4_INTEL 0x1
#define CL_AVC_ME_MINOR_4x8_INTEL 0x2
#define CL_AVC_ME_MINOR_4x4_INTEL 0x3
#define CL_AVC_ME_MAJOR_FORWARD_INTEL 0x0
#define CL_AVC_ME_MAJOR_BACKWARD_INTEL 0x1
#define CL_AVC_ME_MAJOR_BIDIRECTIONAL_INTEL 0x2
#define CL_AVC_ME_PARTITION_MASK_ALL_INTEL 0x0
#define CL_AVC_ME_PARTITION_MASK_16x16_INTEL 0x7E
#define CL_AVC_ME_PARTITION_MASK_16x8_INTEL 0x7D
#define CL_AVC_ME_PARTITION_MASK_8x16_INTEL 0x7B
#define CL_AVC_ME_PARTITION_MASK_8x8_INTEL 0x77
#define CL_AVC_ME_PARTITION_MASK_8x4_INTEL 0x6F
#define CL_AVC_ME_PARTITION_MASK_4x8_INTEL 0x5F
#define CL_AVC_ME_PARTITION_MASK_4x4_INTEL 0x3F
#define CL_AVC_ME_SEARCH_WINDOW_EXHAUSTIVE_INTEL 0x0
#define CL_AVC_ME_SEARCH_WINDOW_SMALL_INTEL 0x1
#define CL_AVC_ME_SEARCH_WINDOW_TINY_INTEL 0x2
#define CL_AVC_ME_SEARCH_WINDOW_EXTRA_TINY_INTEL 0x3
#define CL_AVC_ME_SEARCH_WINDOW_DIAMOND_INTEL 0x4
#define CL_AVC_ME_SEARCH_WINDOW_LARGE_DIAMOND_INTEL 0x5
#define CL_AVC_ME_SEARCH_WINDOW_RESERVED0_INTEL 0x6
#define CL_AVC_ME_SEARCH_WINDOW_RESERVED1_INTEL 0x7
#define CL_AVC_ME_SEARCH_WINDOW_CUSTOM_INTEL 0x8
#define CL_AVC_ME_SEARCH_WINDOW_16x12_RADIUS_INTEL 0x9
#define CL_AVC_ME_SEARCH_WINDOW_4x4_RADIUS_INTEL 0x2
#define CL_AVC_ME_SEARCH_WINDOW_2x2_RADIUS_INTEL 0xa
#define CL_AVC_ME_SAD_ADJUST_MODE_NONE_INTEL 0x0
#define CL_AVC_ME_SAD_ADJUST_MODE_HAAR_INTEL 0x2
#define CL_AVC_ME_SUBPIXEL_MODE_INTEGER_INTEL 0x0
#define CL_AVC_ME_SUBPIXEL_MODE_HPEL_INTEL 0x1
#define CL_AVC_ME_SUBPIXEL_MODE_QPEL_INTEL 0x3
#define CL_AVC_ME_COST_PRECISION_QPEL_INTEL 0x0
#define CL_AVC_ME_COST_PRECISION_HPEL_INTEL 0x1
#define CL_AVC_ME_COST_PRECISION_PEL_INTEL 0x2
#define CL_AVC_ME_COST_PRECISION_DPEL_INTEL 0x3
#define CL_AVC_ME_BIDIR_WEIGHT_QUARTER_INTEL 0x10
#define CL_AVC_ME_BIDIR_WEIGHT_THIRD_INTEL 0x15
#define CL_AVC_ME_BIDIR_WEIGHT_HALF_INTEL 0x20
#define CL_AVC_ME_BIDIR_WEIGHT_TWO_THIRD_INTEL 0x2B
#define CL_AVC_ME_BIDIR_WEIGHT_THREE_QUARTER_INTEL 0x30
#define CL_AVC_ME_BORDER_REACHED_LEFT_INTEL 0x0
#define CL_AVC_ME_BORDER_REACHED_RIGHT_INTEL 0x2
#define CL_AVC_ME_BORDER_REACHED_TOP_INTEL 0x4
#define CL_AVC_ME_BORDER_REACHED_BOTTOM_INTEL 0x8
#define CL_AVC_ME_SKIP_BLOCK_PARTITION_16x16_INTEL 0x0
#define CL_AVC_ME_SKIP_BLOCK_PARTITION_8x8_INTEL 0x4000
#define CL_AVC_ME_SKIP_BLOCK_16x16_FORWARD_ENABLE_INTEL ( 0x1 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_16x16_BACKWARD_ENABLE_INTEL ( 0x2 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_16x16_DUAL_ENABLE_INTEL ( 0x3 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_FORWARD_ENABLE_INTEL ( 0x55 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_BACKWARD_ENABLE_INTEL ( 0xAA << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_DUAL_ENABLE_INTEL ( 0xFF << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_0_FORWARD_ENABLE_INTEL ( 0x1 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_0_BACKWARD_ENABLE_INTEL ( 0x2 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_1_FORWARD_ENABLE_INTEL ( 0x1 << 26 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_1_BACKWARD_ENABLE_INTEL ( 0x2 << 26 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_2_FORWARD_ENABLE_INTEL ( 0x1 << 28 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_2_BACKWARD_ENABLE_INTEL ( 0x2 << 28 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_3_FORWARD_ENABLE_INTEL ( 0x1 << 30 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_3_BACKWARD_ENABLE_INTEL ( 0x2 << 30 )
#define CL_AVC_ME_BLOCK_BASED_SKIP_4x4_INTEL 0x00
#define CL_AVC_ME_BLOCK_BASED_SKIP_8x8_INTEL 0x80
#define CL_AVC_ME_INTRA_16x16_INTEL 0x0
#define CL_AVC_ME_INTRA_8x8_INTEL 0x1
#define CL_AVC_ME_INTRA_4x4_INTEL 0x2
#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_16x16_INTEL 0x6
#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_8x8_INTEL 0x5
#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_4x4_INTEL 0x3
#define CL_AVC_ME_INTRA_NEIGHBOR_LEFT_MASK_ENABLE_INTEL 0x60
#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_MASK_ENABLE_INTEL 0x10
#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_RIGHT_MASK_ENABLE_INTEL 0x8
#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_LEFT_MASK_ENABLE_INTEL 0x4
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_INTEL 0x0
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DC_INTEL 0x2
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_LEFT_INTEL 0x3
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_RIGHT_INTEL 0x4
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_PLANE_INTEL 0x4
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_RIGHT_INTEL 0x5
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_DOWN_INTEL 0x6
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_LEFT_INTEL 0x7
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_UP_INTEL 0x8
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_DC_INTEL 0x0
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_VERTICAL_INTEL 0x2
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_PLANE_INTEL 0x3
#define CL_AVC_ME_FRAME_FORWARD_INTEL 0x1
#define CL_AVC_ME_FRAME_BACKWARD_INTEL 0x2
#define CL_AVC_ME_FRAME_DUAL_INTEL 0x3
#define CL_AVC_ME_SLICE_TYPE_PRED_INTEL 0x0
#define CL_AVC_ME_SLICE_TYPE_BPRED_INTEL 0x1
#define CL_AVC_ME_SLICE_TYPE_INTRA_INTEL 0x2
#define CL_AVC_ME_INTERLACED_SCAN_TOP_FIELD_INTEL 0x0
#define CL_AVC_ME_INTERLACED_SCAN_BOTTOM_FIELD_INTEL 0x1
#ifdef __cplusplus
}
#endif
#endif /* __CL_EXT_INTEL_H */

167
include/triton/external/CL/cl_gl.h vendored Normal file
View File

@@ -0,0 +1,167 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
#ifndef __OPENCL_CL_GL_H
#define __OPENCL_CL_GL_H
#ifdef __APPLE__
#include <OpenCL/cl.h>
#else
#include "cl.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
typedef cl_uint cl_gl_object_type;
typedef cl_uint cl_gl_texture_info;
typedef cl_uint cl_gl_platform_info;
typedef struct __GLsync *cl_GLsync;
/* cl_gl_object_type = 0x2000 - 0x200F enum values are currently taken */
#define CL_GL_OBJECT_BUFFER 0x2000
#define CL_GL_OBJECT_TEXTURE2D 0x2001
#define CL_GL_OBJECT_TEXTURE3D 0x2002
#define CL_GL_OBJECT_RENDERBUFFER 0x2003
#define CL_GL_OBJECT_TEXTURE2D_ARRAY 0x200E
#define CL_GL_OBJECT_TEXTURE1D 0x200F
#define CL_GL_OBJECT_TEXTURE1D_ARRAY 0x2010
#define CL_GL_OBJECT_TEXTURE_BUFFER 0x2011
/* cl_gl_texture_info */
#define CL_GL_TEXTURE_TARGET 0x2004
#define CL_GL_MIPMAP_LEVEL 0x2005
#define CL_GL_NUM_SAMPLES 0x2012
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromGLBuffer(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLuint /* bufobj */,
int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromGLTexture(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLenum /* target */,
cl_GLint /* miplevel */,
cl_GLuint /* texture */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromGLRenderbuffer(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLuint /* renderbuffer */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetGLObjectInfo(cl_mem /* memobj */,
cl_gl_object_type * /* gl_object_type */,
cl_GLuint * /* gl_object_name */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetGLTextureInfo(cl_mem /* memobj */,
cl_gl_texture_info /* param_name */,
size_t /* param_value_size */,
void * /* param_value */,
size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireGLObjects(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseGLObjects(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
/* Deprecated OpenCL 1.1 APIs */
extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL
clCreateFromGLTexture2D(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLenum /* target */,
cl_GLint /* miplevel */,
cl_GLuint /* texture */,
cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL
clCreateFromGLTexture3D(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLenum /* target */,
cl_GLint /* miplevel */,
cl_GLuint /* texture */,
cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
/* cl_khr_gl_sharing extension */
#define cl_khr_gl_sharing 1
typedef cl_uint cl_gl_context_info;
/* Additional Error Codes */
#define CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR -1000
/* cl_gl_context_info */
#define CL_CURRENT_DEVICE_FOR_GL_CONTEXT_KHR 0x2006
#define CL_DEVICES_FOR_GL_CONTEXT_KHR 0x2007
/* Additional cl_context_properties */
#define CL_GL_CONTEXT_KHR 0x2008
#define CL_EGL_DISPLAY_KHR 0x2009
#define CL_GLX_DISPLAY_KHR 0x200A
#define CL_WGL_HDC_KHR 0x200B
#define CL_CGL_SHAREGROUP_KHR 0x200C
extern CL_API_ENTRY cl_int CL_API_CALL
clGetGLContextInfoKHR(const cl_context_properties * /* properties */,
cl_gl_context_info /* param_name */,
size_t /* param_value_size */,
void * /* param_value */,
size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetGLContextInfoKHR_fn)(
const cl_context_properties * properties,
cl_gl_context_info param_name,
size_t param_value_size,
void * param_value,
size_t * param_value_size_ret);
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_GL_H */

74
include/triton/external/CL/cl_gl_ext.h vendored Normal file
View File

@@ -0,0 +1,74 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
/* cl_gl_ext.h contains vendor (non-KHR) OpenCL extensions which have */
/* OpenGL dependencies. */
#ifndef __OPENCL_CL_GL_EXT_H
#define __OPENCL_CL_GL_EXT_H
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __APPLE__
#include <OpenCL/cl_gl.h>
#else
#include "cl_gl.h"
#endif
/*
* For each extension, follow this template
* cl_VEN_extname extension */
/* #define cl_VEN_extname 1
* ... define new types, if any
* ... define new tokens, if any
* ... define new APIs, if any
*
* If you need GLtypes here, mirror them with a cl_GLtype, rather than including a GL header
* This allows us to avoid having to decide whether to include GL headers or GLES here.
*/
/*
* cl_khr_gl_event extension
* See section 9.9 in the OpenCL 1.1 spec for more information
*/
#define CL_COMMAND_GL_FENCE_SYNC_OBJECT_KHR 0x200D
extern CL_API_ENTRY cl_event CL_API_CALL
clCreateEventFromGLsyncKHR(cl_context /* context */,
cl_GLsync /* cl_GLsync */,
cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_GL_EXT_H */

1458
include/triton/external/CL/cl_platform.h vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
/**********************************************************************************
* Copyright (c) 2008-2016 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/*****************************************************************************\
Copyright (c) 2013-2016 Intel Corporation All Rights Reserved.
THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL OR ITS
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE
MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File Name: cl_va_api_media_sharing_intel.h
Abstract:
Notes:
\*****************************************************************************/
#ifndef __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
#define __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
#include "cl.h"
#include "cl_platform.h"
#include <va/va.h>
#ifdef __cplusplus
extern "C" {
#endif
/******************************************
* cl_intel_va_api_media_sharing extension *
*******************************************/
#define cl_intel_va_api_media_sharing 1
/* error codes */
#define CL_INVALID_VA_API_MEDIA_ADAPTER_INTEL -1098
#define CL_INVALID_VA_API_MEDIA_SURFACE_INTEL -1099
#define CL_VA_API_MEDIA_SURFACE_ALREADY_ACQUIRED_INTEL -1100
#define CL_VA_API_MEDIA_SURFACE_NOT_ACQUIRED_INTEL -1101
/* cl_va_api_device_source_intel */
#define CL_VA_API_DISPLAY_INTEL 0x4094
/* cl_va_api_device_set_intel */
#define CL_PREFERRED_DEVICES_FOR_VA_API_INTEL 0x4095
#define CL_ALL_DEVICES_FOR_VA_API_INTEL 0x4096
/* cl_context_info */
#define CL_CONTEXT_VA_API_DISPLAY_INTEL 0x4097
/* cl_mem_info */
#define CL_MEM_VA_API_MEDIA_SURFACE_INTEL 0x4098
/* cl_image_info */
#define CL_IMAGE_VA_API_PLANE_INTEL 0x4099
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_VA_API_MEDIA_SURFACES_INTEL 0x409A
#define CL_COMMAND_RELEASE_VA_API_MEDIA_SURFACES_INTEL 0x409B
typedef cl_uint cl_va_api_device_source_intel;
typedef cl_uint cl_va_api_device_set_intel;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetDeviceIDsFromVA_APIMediaAdapterINTEL(
cl_platform_id /* platform */,
cl_va_api_device_source_intel /* media_adapter_type */,
void* /* media_adapter */,
cl_va_api_device_set_intel /* media_adapter_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL * clGetDeviceIDsFromVA_APIMediaAdapterINTEL_fn)(
cl_platform_id /* platform */,
cl_va_api_device_source_intel /* media_adapter_type */,
void* /* media_adapter */,
cl_va_api_device_set_intel /* media_adapter_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromVA_APIMediaSurfaceINTEL(
cl_context /* context */,
cl_mem_flags /* flags */,
VASurfaceID* /* surface */,
cl_uint /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL * clCreateFromVA_APIMediaSurfaceINTEL_fn)(
cl_context /* context */,
cl_mem_flags /* flags */,
VASurfaceID* /* surface */,
cl_uint /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireVA_APIMediaSurfacesINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireVA_APIMediaSurfacesINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseVA_APIMediaSurfacesINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseVA_APIMediaSurfacesINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H */

59
include/triton/external/CL/opencl.h vendored Normal file
View File

@@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_H
#define __OPENCL_H
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __APPLE__
#include <OpenCL/cl.h>
#include <OpenCL/cl_gl.h>
#include <OpenCL/cl_gl_ext.h>
#include <OpenCL/cl_ext.h>
#else
#include "cl.h"
#include "cl_gl.h"
#include "cl_gl_ext.h"
#include "cl_ext.h"
#endif
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_H */

14503
include/triton/external/CUDA/cuda.h vendored Executable file

File diff suppressed because it is too large Load Diff

6281
include/triton/external/CUDA/nvml.h vendored Executable file

File diff suppressed because it is too large Load Diff

3067
include/triton/external/half.hpp vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
#pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_
#include <string>
#include <list>
#include "value.h"
#include "visitor.h"
namespace triton{
namespace ir{
class context;
class function;
class instruction;
/* Basic Block */
class basic_block: public value{
public:
// instruction iterator types
typedef std::list<instruction*> inst_list_t;
typedef inst_list_t::iterator iterator;
typedef inst_list_t::const_iterator const_iterator;
typedef inst_list_t::reverse_iterator reverse_iterator;
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
private:
// constructors
basic_block(context &ctx, const std::string &name, function *parent);
public:
// accessors
function* get_parent() { return parent_; }
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
iterator get_first_non_phi();
// get instruction list
inst_list_t &get_inst_list() { return inst_list_; }
void erase(instruction *i) { inst_list_.remove(i); }
// instruction iterator functions
inline iterator begin() { return inst_list_.begin(); }
inline const_iterator begin() const { return inst_list_.begin(); }
inline iterator end () { return inst_list_.end(); }
inline const_iterator end () const { return inst_list_.end(); }
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
inline reverse_iterator rend () { return inst_list_.rend(); }
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
inline size_t size() const { return inst_list_.size(); }
inline bool empty() const { return inst_list_.empty(); }
inline const instruction &front() const { return *inst_list_.front(); }
inline instruction &front() { return *inst_list_.front(); }
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
// predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
const std::vector<basic_block*>& get_successors() const { return succs_; }
void add_predecessor(basic_block* pred);
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent);
// visitor
void accept(visitor *v) { v->visit_basic_block(this); }
private:
context &ctx_;
std::string name_;
function *parent_;
std::vector<basic_block*> preds_;
std::vector<basic_block*> succs_;
inst_list_t inst_list_;
};
}
}
#endif

157
include/triton/ir/builder.h Normal file
View File

@@ -0,0 +1,157 @@
#pragma once
#ifndef _TRITON_IR_BUILDER_H_
#define _TRITON_IR_BUILDER_H_
#include <vector>
#include <string>
#include "instructions.h"
#include "basic_block.h"
#include "type.h"
namespace triton{
namespace ir{
class basic_block;
class value;
class type;
class constant_int;
class instruction;
class context;
class phi_node;
/* Builder */
class builder{
typedef basic_block::iterator iterator;
public:
// Constructor
builder(context &ctx);
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
void set_insert_point_after(instruction* i);
void set_insert_point(basic_block* block);
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int32(unsigned val);
// Types
type *get_void_ty();
type *get_int1_ty();
type *get_int8_ty();
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_half_ty();
type *get_float_ty();
type *get_double_ty();
// Insert
template<typename InstTy>
InstTy* insert(InstTy *inst, const std::string &name = ""){
assert(block_);
block_->get_inst_list().insert(insert_point_, inst);
inst->set_parent(block_);
inst->set_name(name);
// for(ir::value* op: inst->ops())
// op->add_use(inst);
return inst;
}
// terminator instructions
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Cast instructions
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_ui(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_ext(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = "");
value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = "");
value *create_downcast(value *arg, const std::string &name = "");
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
// Binary instructions
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs, const std::string &name = "");
value *create_fdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_frem(value *lhs, value *rhs, const std::string &name = "");
value *create_fadd(value *lhs, value *rhs, const std::string &name = "");
value *create_fsub(value *lhs, value *rhs, const std::string &name = "");
value *create_mul(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_udiv(value *lhs, value *rhs, const std::string &name = "");
value *create_srem(value *lhs, value *rhs, const std::string &name = "");
value *create_urem(value *lhs, value *rhs, const std::string &name = "");
value *create_add(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name = "");
// Comparison (int)
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpULE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpULT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpUGE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpUGT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpNE(value *lhs, value *rhs, const std::string &name = "");
// Comparison (float)
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpONE(value *lhs, value *rhs, const std::string &name = "");
// Logical
value *create_and(value *lhs, value *rhs, const std::string &name = "");
value *create_xor(value *lhs, value *rhs, const std::string &name = "");
value *create_or(value *lhs, value *rhs, const std::string &name = "");
// Unary
// value *create_fneg(value *arg, const std::string &name = "");
// value *create_neg(value *arg, const std::string &name = "");
// value *create_not(value *arg, const std::string &name = "");
// Input/Output
value *create_load(value *arg, const std::string &name = "");
value *create_store(value *ptr, value *val, const std::string &name = "");
value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = "");
// Tile instruction
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
// Built-in instruction
value *create_get_program_id(unsigned axis, const std::string &name = "");
value *create_get_num_program(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
private:
context &ctx_;
basic_block *block_;
iterator insert_point_;
};
}
}
#endif

View File

@@ -0,0 +1,113 @@
#pragma once
#ifndef _TRITON_IR_CONSTANT_H_
#define _TRITON_IR_CONSTANT_H_
#include "enums.h"
#include "value.h"
#include <cassert>
#include "visitor.h"
namespace triton{
namespace ir{
class type;
class context;
/* Constant */
class constant: public user{
protected:
using user::user;
public:
static constant* get_all_ones_value(type *ty);
static constant* get_null_value(type *ty);
virtual std::string repr() const = 0;
};
/* Undef value */
class undef_value: public constant{
private:
undef_value(type *ty);
public:
static undef_value* get(type* ty);
std::string repr() const { return "undef"; }
void accept(visitor* vst) { vst->visit_undef_value(this); }
};
/* Constant int */
class constant_int: public constant{
protected:
constant_int(type *ty, uint64_t value);
public:
virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_int(this); }
protected:
uint64_t value_;
};
/* Constant fp */
class constant_fp: public constant{
constant_fp(type *ty, double value);
public:
double get_value() { return value_; }
static constant* get_negative_zero(type *ty);
static constant* get_zero_value_for_negation(type *ty);
static constant* get(context &ctx, double v);
static constant* get(type *ty, double v);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_fp(this); }
private:
double value_;
};
/* Global Value */
class global_value: public constant {
public:
enum linkage_types_t {
external
};
public:
global_value(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space);
std::string repr() const { return get_name(); }
private:
linkage_types_t linkage_;
};
/* global object */
class global_object: public global_value {
public:
global_object(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space = 0);
std::string repr() const { return get_name(); }
};
/* global variable */
class alloc_const: public global_object {
public:
alloc_const(type *ty, constant_int *size,
const std::string &name = "");
std::string repr() const { return get_name(); }
void accept(visitor* vst) { vst->visit_alloc_const(this); }
};
}
}
#endif

View File

@@ -0,0 +1,27 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_H_
#define _TRITON_IR_CONTEXT_H_
#include <memory>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class type;
class context_impl;
/* Context */
class context {
public:
context();
public:
std::shared_ptr<context_impl> p_impl;
};
}
}
#endif

View File

@@ -0,0 +1,43 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_
#include <map>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class context;
class constant;
class constant_int;
class constant_fp;
class undef_value;
/* Context impl */
class context_impl {
public:
// constructors
context_impl(context &ctx);
public:
// primitive types
type void_ty, label_ty, half_ty, float_ty, double_ty;
// derived types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
// Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
// undef values
std::map<type*, undef_value*> uv_constants_;
};
}
}
#endif

149
include/triton/ir/enums.h Normal file
View File

@@ -0,0 +1,149 @@
#pragma once
#ifndef _TRITON_IR_ENUMS_H_
#define _TRITON_IR_ENUMS_H_
namespace triton{
namespace ir{
enum binary_op_t {
Add,
FAdd,
Sub,
FSub,
Mul,
FMul,
UDiv,
SDiv,
FDiv,
URem,
SRem,
FRem,
Shl,
LShr,
AShr,
And,
Or,
Xor
};
enum cast_op_t {
Trunc,
ZExt,
SExt,
FPTrunc,
FPExt,
UIToFP,
SIToFP,
FPToUI,
FPToSI,
PtrToInt,
IntToPtr,
BitCast,
AddrSpaceCast
};
enum cmp_pred_t {
FIRST_FCMP_PREDICATE,
FCMP_FALSE,
FCMP_OEQ,
FCMP_OGT,
FCMP_OGE,
FCMP_OLT,
FCMP_OLE,
FCMP_ONE,
FCMP_ORD,
FCMP_UNO,
FCMP_UEQ,
FCMP_UGT,
FCMP_UGE,
FCMP_ULT,
FCMP_ULE,
FCMP_UNE,
FCMP_TRUE,
LAST_FCMP_PREDICATE,
FIRST_ICMP_PREDICATE,
ICMP_EQ,
ICMP_NE,
ICMP_UGT,
ICMP_UGE,
ICMP_ULT,
ICMP_ULE,
ICMP_SGT,
ICMP_SGE,
ICMP_SLT,
ICMP_SLE,
LAST_ICMP_PREDICATE
};
enum value_id_t: unsigned {
/* ------------ *
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// phi
INST_PHI,
// arithmetic
INST_BINOP,
INST_GETELEMENTPTR,
INST_SELECT,
INST_SQRT,
// cmp
INST_ICMP,
INST_FCMP,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
INST_CAST_SEXT,
INST_CAST_FP_TRUNC,
INST_CAST_FP_EXT,
INST_CAST_UI_TO_FP,
INST_CAST_SI_TO_FP,
INST_CAST_FP_TO_UI,
INST_CAST_FP_TO_SI,
INST_CAST_PTR_TO_INT,
INST_CAST_INT_TO_PTR,
INST_CAST_BIT_CAST,
INST_CAST_ADDR_SPACE_CAST,
// terminators
INST_RETURN,
INST_COND_BRANCH,
INST_UNCOND_BRANCH,
// io
INST_UNMASKED_LOAD,
INST_MASKED_LOAD,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// retile
INST_RESHAPE,
INST_SPLAT,
INST_BROADCAST,
INST_DOWNCAST,
// builtin
INST_GET_PROGRAM_ID,
INST_GET_NUM_PROGRAMS,
// atomics
INST_ATOMIC_CAS,
INST_ATOMIC_EXCH,
INST_ATOMIC_ADD,
// array arithmetic
INST_TRANS,
INST_REDUCE,
INST_DOT,
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_RECOALESCE,
INST_BARRIER,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE
};
}
}
#endif

View File

@@ -0,0 +1,132 @@
#pragma once
#ifndef _TRITON_IR_FUNCTION_H_
#define _TRITON_IR_FUNCTION_H_
#include <string>
#include <map>
#include "value.h"
#include "constant.h"
namespace triton{
namespace ir{
class function;
class function_type;
class module;
class basic_block;
/* Argument */
class argument: public value{
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
void accept(visitor *v);
private:
function *parent_;
unsigned arg_no_;
};
/* Attribute */
enum attribute_kind_t {
readonly,
writeonly,
noalias,
aligned,
multiple_of
};
class attribute {
public:
attribute(attribute_kind_t kind, unsigned value = 0):
kind_(kind), value_(value){}
bool operator<(const attribute& other) const {
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
}
attribute_kind_t get_kind() const {
return kind_;
}
unsigned get_value() const {
return value_;
}
bool is_llvm_attr() const {
return kind_ != multiple_of;
}
std::string repr() const {
switch(kind_){
case readonly: return ".readonly";
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".readonly";
default: break;
}
assert(false);
return "";
}
private:
attribute_kind_t kind_;
unsigned value_;
};
/* Function */
class function: public global_object{
typedef std::vector<argument*> args_t;
typedef args_t::iterator arg_iterator;
typedef args_t::const_iterator const_arg_iterator;
typedef std::vector<basic_block*> blocks_t;
typedef blocks_t::iterator block_iterator;
typedef blocks_t::const_iterator const_block_iterator;
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
private:
function(function_type *ty, linkage_types_t linkage,
const std::string &name = "", module *parent = nullptr);
public:
// accessors
const args_t &args() { return args_; }
function_type* get_fn_type() { return fn_ty_; }
// factory methods
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
// visitor
void accept(visitor *v) { v->visit_function(this); }
private:
module *parent_;
bool init_;
function_type *fn_ty_;
args_t args_;
blocks_t blocks_;
attr_map_t attrs_;
};
}
}
#endif

View File

@@ -0,0 +1,808 @@
#pragma once
#ifndef _TRITON_IR_INSTRUCTIONS_H_
#define _TRITON_IR_INSTRUCTIONS_H_
#include <vector>
#include <map>
#include "triton/ir/enums.h"
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "triton/ir/visitor.h"
#define _TRITON_DEFINE_CLONE(name) \
ir::instruction* clone_impl() const { return new name(*this); }
#define _TRITON_DEFINE_ACCEPT(name) \
void accept(visitor* v) { v->visit_ ## name (this); }
namespace triton{
namespace ir{
class constant_int;
class constant;
class make_range;
class basic_block;
class context;
class visitor;
//===----------------------------------------------------------------------===//
// instruction classes
//===----------------------------------------------------------------------===//
class result_reference;
class instruction: public user{
public:
virtual std::string repr_impl() const = 0;
private:
virtual ir::instruction* clone_impl() const = 0;
protected:
// constructors
instruction(type *ty, value_id_t ity, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// parent
void set_parent(basic_block *block) { parent_ = block; }
const basic_block *get_parent() const { return parent_; }
basic_block *get_parent() { return parent_; }
void erase_from_parent();
// helpers
bool has_tile_result_or_op();
// repr
std::string repr() const { return repr_impl(); }
// metadata
void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
for(auto it = op_begin(); it != op_end(); it++)
(*it)->add_use(res);
res->parent_ = nullptr;
return res;
}
// instruction id
value_id_t get_id() const { return id_; }
private:
basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
value_id_t id_;
};
//===----------------------------------------------------------------------===//
// phi_node classes
//===----------------------------------------------------------------------===//
class phi_node: public instruction {
private:
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
std::string repr_impl() const { return "phi"; }
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
void add_incoming(value *v, basic_block *block);
// Type
void set_type(type *ty) { ty_ = ty; }
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(phi_node)
_TRITON_DEFINE_ACCEPT(phi_node)
private:
unsigned num_reserved_;
std::vector<basic_block*> blocks_;
};
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction {
public:
typedef binary_op_t op_t;
private:
std::string repr_impl() const;
protected:
// Constructors
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
public:
// Get operand
binary_op_t get_op() const { return op_; }
// Bool
bool is_terminator() const;
bool is_binary_op() const;
bool is_int_div_rem() const;
bool is_shift() const;
bool is_cast() const;
bool is_int_mult() const;
bool is_int_add_sub() const;
bool is_int_div() const;
bool is_int_rem() const;
bool is_shl() const;
bool is_shr() const;
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
// Factory methods
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(binary_operator)
_TRITON_DEFINE_ACCEPT(binary_operator)
public:
binary_op_t op_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
};
//===----------------------------------------------------------------------===//
// cmp_inst classes
//===----------------------------------------------------------------------===//
class cmp_inst: public instruction{
public:
typedef cmp_pred_t pred_t;
private:
std::string repr_impl() const;
protected:
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
static bool is_fp_predicate(cmp_pred_t pred);
static bool is_int_predicate(cmp_pred_t pred);
static type* make_cmp_result_type(type *ty);
public:
cmp_pred_t get_pred() const { return pred_; }
private:
cmp_pred_t pred_;
};
class icmp_inst: public cmp_inst {
icmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(icmp_inst)
_TRITON_DEFINE_ACCEPT(icmp_inst)
};
class fcmp_inst: public cmp_inst {
fcmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(fcmp_inst)
_TRITON_DEFINE_ACCEPT(fcmp_inst)
};
//===----------------------------------------------------------------------===//
// unary_inst classes
//===----------------------------------------------------------------------===//
class unary_inst: public instruction {
protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
class cast_inst: public unary_inst{
private:
std::string repr_impl() const;
protected:
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
: unary_inst(ty, id, v, name, next), op_(op) { }
private:
static bool is_valid(cast_op_t op, value *arg, type *ty);
public:
// accessors
cast_op_t get_op() const { return op_; }
// factory methods
static cast_inst *create(cast_op_t op, value *arg, type *ty,
const std::string &name = "", instruction *next = nullptr);
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_ACCEPT(cast_inst)
private:
cast_op_t op_;
};
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
class name : public cast_inst { \
_TRITON_DEFINE_CLONE(name) \
friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, id, v, name, next, op){ } \
};
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
//===----------------------------------------------------------------------===//
// terminator_inst classes
//===----------------------------------------------------------------------===//
class terminator_inst: public instruction{
using instruction::instruction;
};
// return instruction
class return_inst: public terminator_inst {
private:
std::string repr_impl() const { return "ret"; }
return_inst(context &ctx, value *ret_val, instruction *next);
public:
// accessors
value *get_return_value()
{ return get_num_operands() ? get_operand(0) : nullptr; }
unsigned get_num_successors() const { return 0; }
// factory methods
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
_TRITON_DEFINE_CLONE(return_inst)
_TRITON_DEFINE_ACCEPT(return_inst)
};
// base branch instruction
class branch_inst: public terminator_inst{
private:
std::string repr_impl() const { return "br"; }
protected:
using terminator_inst::terminator_inst;
public:
static branch_inst* create(basic_block *dest,
instruction *next = nullptr);
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
instruction *next = nullptr);
};
// conditional branch
class cond_branch_inst: public branch_inst {
private:
friend class branch_inst;
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
public:
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
value *get_cond() { return get_operand(2); }
_TRITON_DEFINE_CLONE(cond_branch_inst)
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
};
// unconditional branch
class uncond_branch_inst: public branch_inst {
private:
friend class branch_inst;
uncond_branch_inst(basic_block *dst, instruction *next);
public:
basic_block *get_dest() { return (basic_block*)get_operand(0); }
_TRITON_DEFINE_CLONE(uncond_branch_inst)
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
};
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//
class getelementptr_inst: public instruction {
private:
std::string repr_impl() const { return "getelementptr"; }
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
private:
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
public:
// accessors
type *get_source_elt_ty() { return source_elt_ty; }
op_iterator idx_begin() { return op_begin() + 1; }
op_iterator idx_end() { return op_end(); }
value *get_pointer_operand() { return *op_begin(); }
// factory methods
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(getelementptr_inst)
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
private:
type *source_elt_ty;
type *res_elt_ty;
};
//===----------------------------------------------------------------------===//
// load_inst/store_inst classes
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
};
// load
class load_inst: public io_inst {
protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
private:
static type *get_pointee_type(type *ty);
};
// unmasked load
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load"; }
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
};
// masked load
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load"; }
masked_load_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(1); }
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
_TRITON_DEFINE_ACCEPT(masked_load_inst)
};
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
value *get_value_operand() { return get_operand(1); }
};
// unmasked_store
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
};
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
_TRITON_DEFINE_ACCEPT(masked_store_inst)
};
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
// retile
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
};
// reshape
class reshape_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "reshape"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
_TRITON_DEFINE_ACCEPT(reshape_inst)
};
// splat
class splat_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "splat"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
_TRITON_DEFINE_ACCEPT(splat_inst)
};
// broadcast
class broadcast_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "broadcast"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
_TRITON_DEFINE_ACCEPT(broadcast_inst)
};
// downcast
class downcast_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "downcast"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(downcast_inst)
_TRITON_DEFINE_ACCEPT(downcast_inst)
};
//===----------------------------------------------------------------------===//
// builtin_inst classes
//===----------------------------------------------------------------------===//
class builtin_inst: public instruction{
protected:
using instruction::instruction;
};
class get_program_id_inst: public builtin_inst {
private:
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_program_id_inst)
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
private:
unsigned axis_;
};
class get_num_program_inst: public builtin_inst {
private:
get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_program_inst)
_TRITON_DEFINE_ACCEPT(get_num_program_inst)
private:
unsigned axis_;
};
class atomic_cas_inst: public builtin_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; }
_TRITON_DEFINE_CLONE(atomic_cas_inst)
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class atomic_exch_inst: public builtin_inst {
private:
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_exch"; }
_TRITON_DEFINE_CLONE(atomic_exch_inst)
_TRITON_DEFINE_ACCEPT(atomic_exch_inst)
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
};
class atomic_add_inst: public builtin_inst {
private:
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_add"; }
_TRITON_DEFINE_CLONE(atomic_add_inst)
_TRITON_DEFINE_ACCEPT(atomic_add_inst)
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
};
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst)
};
//class outer_inst: public builtin_inst {
//private:
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
//public:
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
//};
class trans_inst: public builtin_inst {
public:
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
private:
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; }
public:
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
const std::vector<int> get_perm() const;
_TRITON_DEFINE_CLONE(trans_inst)
_TRITON_DEFINE_ACCEPT(trans_inst)
private:
std::vector<int> perm_;
};
class sqrt_inst: public builtin_inst {
private:
sqrt_inst(value *arg, const std::string& name, instruction* next);
std::string repr_impl() const { return "sqrt"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(sqrt_inst)
_TRITON_DEFINE_ACCEPT(sqrt_inst)
};
class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN,
FADD, FSUB, FMAX, FMIN
};
private:
static type* get_res_type(value *arg, unsigned axis);
static std::string to_str(op_t op);
private:
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
std::string repr_impl() const { return "reduce"; }
_TRITON_DEFINE_CLONE(reduce_inst)
_TRITON_DEFINE_ACCEPT(reduce_inst)
public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
private:
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
std::string repr_impl() const { return "select"; }
_TRITON_DEFINE_CLONE(select_inst)
_TRITON_DEFINE_ACCEPT(select_inst)
public:
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// intrinsics classes
//===----------------------------------------------------------------------===//
class copy_to_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_to_shared"; }
public:
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
};
class copy_from_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_from_shared"; }
public:
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class recoalesce_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "recoalesce_inst"; }
public:
static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(recoalesce_inst)
_TRITON_DEFINE_ACCEPT(recoalesce_inst)
};
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "barrier"; }
_TRITON_DEFINE_CLONE(barrier_inst)
_TRITON_DEFINE_ACCEPT(barrier_inst)
public:
static barrier_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
};
// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
// so as to enable re-association on nv_static_program_idx which is constant
class make_range_dyn: public instruction {
private:
make_range_dyn(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
_TRITON_DEFINE_CLONE(make_range_dyn)
_TRITON_DEFINE_ACCEPT(make_range_dyn)
public:
static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
};
class make_range_sta: public constant {
private:
make_range_sta(make_range *range);
public:
static make_range_sta *get(make_range* range);
make_range* get_range() const;
std::string repr() const { return "nv_static_program_idx"; }
_TRITON_DEFINE_ACCEPT(make_range_sta)
private:
make_range *range_;
};
/* constant range */
class make_range: public instruction{
make_range(type *ty, constant_int* first, constant_int* last);
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
_TRITON_DEFINE_CLONE(make_range)
_TRITON_DEFINE_ACCEPT(make_range)
public:
static make_range *create(constant_int *first, constant_int *last);
const constant_int* get_first() const;
const constant_int* get_last() const;
private:
constant_int* first_;
constant_int* last_;
};
}
}
#endif

View File

@@ -0,0 +1,31 @@
#pragma once
#ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_
namespace triton{
namespace ir{
/* Metadata */
class metadata{
public:
enum kind_t{
multiple_of
};
private:
metadata(kind_t kind, unsigned value);
public:
static metadata* get(kind_t kind, unsigned value);
private:
kind_t kind_;
unsigned value_;
};
}
}
#endif

117
include/triton/ir/module.h Normal file
View File

@@ -0,0 +1,117 @@
#pragma once
#ifndef _TRITON_IR_MODULE_H_
#define _TRITON_IR_MODULE_H_
#include <map>
#include <set>
#include <stack>
#include <string>
#include <functional>
#include "builder.h"
#include "metadata.h"
namespace triton{
namespace lang{
class iteration_statement;
class compound_statement;
}
namespace ir{
class basic_block;
class phi_node;
class value;
class context;
class function;
class attribute;
class function_type;
class constant;
class global_value;
class alloc_const;
/* Module */
struct scope {
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
struct current_iteration_info_t{
lang::iteration_statement *statement;
basic_block *block;
};
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, context &ctx);
context& get_context();
builder& get_builder();
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty);
// Scope
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
void pop_scope() { scopes_.pop(); }
scope& get_scope() { return scopes_.top(); }
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
// Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
private:
std::string name_;
context &context_;
builder builder_;
std::map<val_key_t, value*> values_;
std::map<val_key_t, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::stack<scope> scopes_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
};
}
}
#endif

18
include/triton/ir/print.h Normal file
View File

@@ -0,0 +1,18 @@
#pragma once
#ifndef _TRITON_IR_PRINT_H_
#define _TRITON_IR_PRINT_H_
#include "builder.h"
namespace triton{
namespace ir{
class module;
void print(module &mod, std::ostream& os);
}
}
#endif

238
include/triton/ir/type.h Normal file
View File

@@ -0,0 +1,238 @@
#pragma once
#ifndef _TRITON_IR_TYPE_H_
#define _TRITON_IR_TYPE_H_
#include <cassert>
#include <vector>
#include <string>
namespace triton{
namespace ir{
class context;
class value;
class integer_type;
class constant_int;
/* Type */
class type {
public:
typedef std::vector<unsigned> tile_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator;
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
public:
enum id_t {
// primitive types
VoidTyID = 0, ///< 0: type with no size
HalfTyID, ///< 1: 16-bit floating point type
FloatTyID, ///< 2: 32-bit floating point type
DoubleTyID, ///< 3: 64-bit floating point type
X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 7: Labels
MetadataTyID, ///< 8: Metadata
TokenTyID, ///< 9: Token
// derived types
IntegerTyID, ///< 10: Arbitrary bit width integers
FunctionTyID, ///< 11: Functions
PointerTyID, ///< 12: Pointers
StructTyID, ///< 13: Struct
TileTyID, ///< 14: Tile
};
public:
//constructors
type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
//destructor
virtual ~type(){}
// accessors
context &get_context() const { return ctx_; }
id_t get_type_id() const { return id_; }
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const;
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
const tile_shapes_t& get_tile_shapes() const;
const size_t get_tile_rank() const;
const size_t get_tile_ranks1() const;
unsigned get_tile_num_elements() const;
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;
type *get_pointer_element_ty() const;
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_tile_ty() const { return id_ == TileTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
bool is_integer_ty(unsigned width) const;
bool is_floating_point_ty() const;
bool is_sized() const ;
// Factory methods
// primitive types
static type *get_void_ty(context &ctx);
static type *get_label_ty(context &ctx);
// half
static type *get_half_ty(context &ctx);
static type *get_float_ty(context &ctx);
static type *get_double_ty(context &ctx);
// integer types
static integer_type *get_int1_ty(context &ctx);
static integer_type *get_int8_ty(context &ctx);
static integer_type *get_int16_ty(context &ctx);
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
// repr
std::string tile_repr() const {
std::string res = get_tile_element_ty()->repr();
auto shapes = get_tile_shapes();
res += "<";
for(size_t i = 0; i < shapes.size(); i++){
if(i > 0)
res += ", ";
res += std::to_string(shapes[i]);
}
res+= ">";
return res;
}
std::string repr() const {
switch(id_) {
case VoidTyID: return "void";
case HalfTyID: return "f16";
case FloatTyID: return "f32";
case DoubleTyID: return "f64";
case X86_FP80TyID: return "f80";
case FP128TyID: return "f128";
case PPC_FP128TyID: return "ppcf128";
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
case TileTyID: return tile_repr();
default: break;
}
assert(false);
return "";
};
private:
context &ctx_;
id_t id_;
protected:
contained_tys_vec_t contained_tys_;
};
class integer_type: public type {
friend class context_impl;
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
};
class composite_type: public type{
protected:
using type::type;
public:
bool index_valid(value *idx) const;
type* get_type_at_index(value *idx) const;
};
class tile_type: public composite_type {
private:
tile_type(type *ty, const tile_shapes_t &shapes);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
const tile_shapes_t& get_shapes() const { return shapes_; }
unsigned get_num_elements() const;
unsigned get_bitwidth() const;
// factory methods
static tile_type* get(type *ty, const tile_shapes_t &shapes);
static tile_type* get_same_shapes(type *ty, type *ref);
private:
tile_shapes_t shapes_;
};
class pointer_type: public type {
private:
pointer_type(type *ty, unsigned address_space);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
unsigned get_address_space() const { return address_space_; }
type *get_element_ty() const { return contained_tys_[0]; }
// factory methods
static pointer_type* get(type *ty, unsigned address_space);
private:
unsigned address_space_;
};
class function_type: public type {
private:
function_type(type *ret_ty, const std::vector<type *> &param_tys);
public:
// accessors
unsigned get_num_params() const { return contained_tys_.size() - 1; }
const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
const_ty_iterator params_end() const { return contained_tys_.end(); }
ty_iterator params_begin() { return contained_tys_.begin() + 1; }
ty_iterator params_end() { return contained_tys_.end(); }
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
type* get_return_ty() const { return contained_tys_.at(0); }
// factory methods
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
};
}
}
#endif

29
include/triton/ir/utils.h Normal file
View File

@@ -0,0 +1,29 @@
#pragma once
#ifndef _TRITON_IR_CFG_H_
#define _TRITON_IR_CFG_H_
#include <vector>
#include <functional>
namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
class value;
class cfg {
public:
static std::vector<basic_block *> reverse_post_order(function* fn);
};
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}
}
#endif

90
include/triton/ir/value.h Normal file
View File

@@ -0,0 +1,90 @@
#pragma once
#ifndef _TRITON_IR_VALUE_H_
#define _TRITON_IR_VALUE_H_
#include <string>
#include <vector>
#include <set>
namespace triton{
namespace ir{
class type;
class use;
class user;
class visitor;
//===----------------------------------------------------------------------===//
// value class
//===----------------------------------------------------------------------===//
class value {
public:
// constructor
value(type *ty, const std::string &name = "");
virtual ~value(){ }
// uses
void add_use(user* arg);
unsigned erase_use(user* arg);
const std::set<user*> &get_users() { return users_; }
virtual void replace_all_uses_with(value *target);
// name
void set_name(const std::string &name);
const std::string &get_name() const { return name_; }
type* get_type() const { return ty_; }
// visitor
virtual void accept(visitor *v) = 0;
private:
std::string name_;
protected:
type *ty_;
std::set<user*> users_;
};
//===----------------------------------------------------------------------===//
// user class
//===----------------------------------------------------------------------===//
class user: public value{
public:
typedef std::vector<value*> ops_t;
typedef ops_t::iterator op_iterator;
typedef ops_t::const_iterator const_op_iterator;
protected:
void resize_ops(unsigned num_ops) { ops_.resize(num_ops + num_hidden_); num_ops_ = num_ops; }
void resize_hidden(unsigned num_hidden) { ops_.resize(num_ops_ + num_hidden); num_hidden_ = num_hidden; }
public:
// Constructor
user(type *ty, unsigned num_ops, const std::string &name = "")
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
}
// Operands
const ops_t& ops() { return ops_; }
op_iterator op_begin() { return ops_.begin(); }
op_iterator op_end() { return ops_.end(); }
void set_operand(unsigned i, value *x);
value *get_operand(unsigned i) const;
unsigned get_num_operands() const ;
unsigned get_num_hidden() const;
// Utils
void replace_all_uses_with(value *target);
void replace_uses_of_with(value *before, value *after);
private:
ops_t ops_;
unsigned num_ops_;
unsigned num_hidden_;
};
}
}
#endif

152
include/triton/ir/visitor.h Normal file
View File

@@ -0,0 +1,152 @@
#pragma once
#ifndef _TRITON_IR_VISITOR_H_
#define _TRITON_IR_VISITOR_H_
namespace triton{
namespace ir{
class value;
class instruction;
class phi_node;
class binary_operator;
class getelementptr_inst;
class icmp_inst;
class fcmp_inst;
class cast_inst;
class trunc_inst;
class z_ext_inst;
class s_ext_inst;
class fp_trunc_inst;
class fp_ext_inst;
class ui_to_fp_inst;
class si_to_fp_inst;
class fp_to_ui_inst;
class fp_to_si_inst;
class ptr_to_int_inst;
class int_to_ptr_inst;
class bit_cast_inst;
class addr_space_cast_inst;
class return_inst;
class cond_branch_inst;
class uncond_branch_inst;
class unmasked_load_inst;
class masked_load_inst;
class unmasked_store_inst;
class masked_store_inst;
class retile_inst;
class reshape_inst;
class splat_inst;
class broadcast_inst;
class downcast_inst;
class get_program_id_inst;
class get_num_program_inst;
class atomic_cas_inst;
class atomic_exch_inst;
class atomic_add_inst;
class dot_inst;
class trans_inst;
class sqrt_inst;
class reduce_inst;
class select_inst;
class recoalesce_inst;
class copy_to_shared_inst;
class copy_from_shared_inst;
class barrier_inst;
class make_range_dyn;
class make_range;
class make_range_sta;
class undef_value;
class constant_int;
class constant_fp;
class global_value;
class global_object;
class alloc_const;
class constant_fp;
class undef_value;
class constant_int;
class constant_fp;
class global_value;
class global_object;
class alloc_const;
class function;
class basic_block;
class argument;
class visitor {
public:
virtual ~visitor() {}
virtual void visit_value(ir::value*);
virtual void visit_basic_block(basic_block*) = 0;
virtual void visit_argument(argument*) = 0;
virtual void visit_phi_node(phi_node*) = 0;
virtual void visit_binary_operator(binary_operator*) = 0;
virtual void visit_getelementptr_inst(getelementptr_inst*) = 0;
virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0;
virtual void visit_return_inst(return_inst*) = 0;
virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
virtual void visit_uncond_branch_inst(uncond_branch_inst*) = 0;
virtual void visit_unmasked_load_inst(unmasked_load_inst*) = 0;
virtual void visit_masked_load_inst(masked_load_inst*) = 0;
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
virtual void visit_reshape_inst(reshape_inst*) = 0;
virtual void visit_splat_inst(splat_inst*) = 0;
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
virtual void visit_downcast_inst(downcast_inst*) = 0;
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
virtual void visit_get_num_program_inst(get_num_program_inst*) = 0;
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
virtual void visit_dot_inst(dot_inst*) = 0;
virtual void visit_trans_inst(trans_inst*) = 0;
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
virtual void visit_reduce_inst(reduce_inst*) = 0;
virtual void visit_select_inst(select_inst*) = 0;
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range(make_range*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_make_range_sta(make_range_sta*) = 0;
virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0;
virtual void visit_alloc_const(alloc_const*) = 0;
};
}
}
#endif

820
include/triton/lang/ast.h Normal file
View File

@@ -0,0 +1,820 @@
#pragma once
#ifndef _WGTCC_AST_H_
#define _WGTCC_AST_H_
#include "error.h"
#include "token.h"
#include "type.h"
#include <cassert>
#include <list>
#include <memory>
#include <string>
class Visitor;
template<typename T> class Evaluator;
class AddrEvaluator;
class Generator;
class Scope;
class Parser;
class ASTNode;
class Token;
class TokenSequence;
// Expressions
class Expr;
class BinaryOp;
class UnaryOp;
class ConditionalOp;
class FuncCall;
class TempVar;
class Constant;
class Identifier;
class Object;
struct Initializer;
class Declaration;
class Enumerator;
// Statements
class Stmt;
class IfStmt;
class ForStmt;
class JumpStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
/*
* AST Node
*/
class ASTNode {
public:
struct Attr{
enum KindT{
MULTIPLEOF,
ALIGNED,
NOALIAS,
READONLY,
WRITEONLY
};
KindT kind;
std::vector<Expr*> vals;
};
using AttrList = std::vector<Attr>;
public:
virtual ~ASTNode() {}
virtual void Accept(Visitor* v) = 0;
protected:
ASTNode() {}
MemPool* pool_ {nullptr};
};
using ExtDecl = ASTNode;
/*
* Statements
*/
class Stmt : public ASTNode {
public:
virtual ~Stmt() {}
protected:
Stmt() {}
};
class EmptyStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static EmptyStmt* New();
virtual ~EmptyStmt() {}
virtual void Accept(Visitor* v);
protected:
EmptyStmt() {}
};
class LabelStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static LabelStmt* New();
~LabelStmt() {}
virtual void Accept(Visitor* v);
std::string Repr() const { return ".L" + std::to_string(tag_); }
protected:
LabelStmt(): tag_(GenTag()) {}
private:
static int GenTag() {
static int tag = 0;
return ++tag;
}
int tag_; // 使用整型的tag值而不直接用字符串
};
class IfStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static IfStmt* New(Expr* cond, Stmt* then, Stmt* els=nullptr);
virtual ~IfStmt() {}
virtual void Accept(Visitor* v);
protected:
IfStmt(Expr* cond, Stmt* then, Stmt* els = nullptr)
: cond_(cond), then_(then), else_(els) {}
private:
Expr* cond_;
Stmt* then_;
Stmt* else_;
};
class ForStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
virtual ~ForStmt() {}
virtual void Accept(Visitor* v);
protected:
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
: body_(body), init_(init), cond_(cond), step_(step) {}
private:
Stmt* body_;
Stmt* init_;
Expr* cond_;
Expr* step_;
};
class JumpStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static JumpStmt* New(LabelStmt* label);
virtual ~JumpStmt() {}
virtual void Accept(Visitor* v);
void SetLabel(LabelStmt* label) { label_ = label; }
protected:
JumpStmt(LabelStmt* label): label_(label) {}
private:
LabelStmt* label_;
};
class ReturnStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ReturnStmt* New(Expr* expr);
virtual ~ReturnStmt() {}
virtual void Accept(Visitor* v);
protected:
ReturnStmt(::Expr* expr): expr_(expr) {}
private:
::Expr* expr_;
};
using StmtList = std::list<Stmt*>;
class CompoundStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static CompoundStmt* New(StmtList& stmts, ::Scope* scope=nullptr);
virtual ~CompoundStmt() {}
virtual void Accept(Visitor* v);
StmtList& Stmts() { return stmts_; }
::Scope* Scope() { return scope_; }
protected:
CompoundStmt(const StmtList& stmts, ::Scope* scope=nullptr)
: stmts_(stmts), scope_(scope) {}
private:
StmtList stmts_;
::Scope* scope_;
};
struct Initializer {
Initializer(Type* type,
int offset,
Expr* expr,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0)
: type_(type),
offset_(offset),
bitFieldBegin_(bitFieldBegin),
bitFieldWidth_(bitFieldWidth),
expr_(expr) {}
bool operator<(const Initializer& rhs) const;
// It could be the object it self or, it will be the member
// that was initialized
Type* type_;
int offset_;
unsigned char bitFieldBegin_;
unsigned char bitFieldWidth_;
Expr* expr_;
};
using InitList = std::set<Initializer>;
class Declaration: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Declaration* New(Object* obj);
virtual ~Declaration() {}
virtual void Accept(Visitor* v);
InitList& Inits() { return inits_; }
Object* Obj() { return obj_; }
void AddInit(Initializer init);
protected:
Declaration(Object* obj): obj_(obj) {}
Object* obj_;
InitList inits_;
};
/*
* Expr
* BinaryOp
* UnaryOp
* ConditionalOp
* FuncCall
* Constant
* Identifier
* Object
* TempVar
*/
class Expr : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
virtual ~Expr() {}
::Type* Type() { return type_.GetPtr(); }
virtual bool IsLVal() = 0;
virtual void TypeChecking() = 0;
void EnsureCompatible(const QualType lhs, const QualType rhs) const;
void EnsureCompatibleOrVoidPointer(const QualType lhs,
const QualType rhs) const;
const Token* Tok() const { return tok_; }
void SetTok(const Token* tok) { tok_ = tok; }
static Expr* MayCast(Expr* expr);
static Expr* MayCast(Expr* expr, QualType desType);
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
virtual bool IsNullPointerConstant() const { return false; }
bool IsConstQualified() const { return type_.IsConstQualified(); }
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
bool IsVolatileQualified() const { return type_.IsVolatileQualified(); }
protected:
// You can construct a expression without specifying a type,
// then the type should be evaluated in TypeChecking()
Expr(const Token* tok, QualType type): tok_(tok), type_(type) {}
const Token* tok_;
QualType type_;
};
/*
* '+', '-', '*', '/', '%', '<', '>', '<<', '>>', '|', '&', '^'
* '=',(复合赋值运算符被拆分为两个运算)
* '==', '!=', '<=', '>=',
* '&&', '||'
* '['(下标运算符), '.'(成员运算符)
* ','(逗号运算符),
*/
class BinaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
friend class Declaration;
public:
static BinaryOp* New(const Token* tok, Expr* lhs, Expr* rhs);
static BinaryOp* New(const Token* tok, int op, Expr* lhs, Expr* rhs);
virtual ~BinaryOp() {}
virtual void Accept(Visitor* v);
// Member ref operator is a lvalue
virtual bool IsLVal() {
switch (op_) {
case '.': return !Type()->ToArray() && lhs_->IsLVal();
case ']': return !Type()->ToArray();
case Token::MASKED_DEREF: return true;
default: return false;
}
}
ArithmType* Convert();
static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type);
virtual void TypeChecking();
void SubScriptingOpTypeChecking();
void MemberRefOpTypeChecking();
void MultiOpTypeChecking();
void AdditiveOpTypeChecking();
void ShiftOpTypeChecking();
void RangeOpTypeChecking();
void MatmulOpTypeChecking();
void MaskedDerefOpTypeChecking();
void RelationalOpTypeChecking();
void EqualityOpTypeChecking();
void BitwiseOpTypeChecking();
void LogicalOpTypeChecking();
void AssignOpTypeChecking();
void CommaOpTypeChecking();
protected:
BinaryOp(const Token* tok, int op, Expr* lhs, Expr* rhs)
: Expr(tok, nullptr), op_(op) {
lhs_ = lhs, rhs_ = rhs;
if (op != '.') {
lhs_ = MayCast(lhs);
rhs_ = MayCast(rhs);
}
}
int op_;
Expr* lhs_;
Expr* rhs_;
};
/*
* Unary Operator:
* '++' (prefix/postfix)
* '--' (prefix/postfix)
* '&' (ADDR)
* '*' (DEREF)
* '+' (PLUS)
* '-' (MINUS)
* '~'
* '!'
* CAST // like (int)3
*/
class UnaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
::Type *Convert();
static int encodeRed(int ax, int tag);
static void decodeRed(int info, int& ax, int& tag);
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void ReduceOpTypeChecking();
void UnaryArithmOpTypeChecking();
void CastOpTypeChecking();
protected:
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
: Expr(operand->Tok(), type), op_(op), info_(info) {
operand_ = operand;
if (op_ != Token::CAST && op_ != Token::ADDR) {
operand_ = MayCast(operand);
}
}
int op_;
int info_;
Expr* operand_;
};
class TransOp: public Expr {
friend class Generator;
public:
using PermInt = std::vector<int>;
public:
static TransOp* New(const PermInt& perm, Expr* operand);
const PermInt& getPerm() const { return perm_; }
void Accept(Visitor* v);
bool IsLVal() { return false; }
void TypeChecking();
protected:
TransOp(const PermInt& perm, Expr* operand)
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
private:
Expr* operand_;
PermInt perm_;
};
// cond ? true false
class ConditionalOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ConditionalOp* New(const Token* tok,
Expr* cond, Expr* exprTrue, Expr* exprFalse);
virtual ~ConditionalOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
ArithmType* Convert();
virtual void TypeChecking();
protected:
ConditionalOp(Expr* cond, Expr* exprTrue, Expr* exprFalse)
: Expr(cond->Tok(), nullptr), cond_(MayCast(cond)),
exprTrue_(MayCast(exprTrue)), exprFalse_(MayCast(exprFalse)) {}
private:
Expr* cond_;
Expr* exprTrue_;
Expr* exprFalse_;
};
class FuncCall : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
using ArgList = std::vector<Expr*>;
public:
static FuncCall* New(Expr* designator, const ArgList& args);
~FuncCall() {}
virtual void Accept(Visitor* v);
// A function call is ofcourse not lvalue
virtual bool IsLVal() { return false; }
ArgList* Args() { return &args_; }
Expr* Designator() { return designator_; }
const std::string& Name() const { return tok_->str_; }
::FuncType* FuncType() { return designator_->Type()->ToFunc(); }
virtual void TypeChecking();
protected:
FuncCall(Expr* designator, const ArgList& args)
: Expr(designator->Tok(), nullptr),
designator_(designator), args_(args) {}
Expr* designator_;
ArgList args_;
};
class Constant: public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Constant* New(const Token* tok, int tag, long val);
static Constant* New(const Token* tok, int tag, double val);
static Constant* New(const Token* tok, int tag, const std::string* val);
~Constant() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
virtual void TypeChecking() {}
long IVal() const { return ival_; }
double FVal() const { return fval_; }
const std::string* SVal() const { return sval_; }
std::string SValRepr() const;
std::string Repr() const { return std::string(".LC") + std::to_string(id_); }
protected:
Constant(const Token* tok, QualType type, long val)
: Expr(tok, type), ival_(val) {}
Constant(const Token* tok, QualType type, double val)
: Expr(tok, type), fval_(val) {}
Constant(const Token* tok, QualType type, const std::string* val)
: Expr(tok, type), sval_(val) {}
union {
long ival_;
double fval_;
struct {
long id_;
const std::string* sval_;
};
};
};
class TempVar : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static TempVar* New(QualType type);
virtual ~TempVar() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return true; }
virtual void TypeChecking() {}
protected:
TempVar(QualType type): Expr(nullptr, type), tag_(GenTag()) {}
private:
static int GenTag() {
static int tag = 0;
return ++tag;
}
int tag_;
};
enum Linkage {
L_NONE,
L_EXTERNAL,
L_INTERNAL,
};
class Identifier: public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static Identifier* New(const Token* tok, QualType type, Linkage linkage, const AttrList& attrList={});
virtual ~Identifier() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
virtual Object* ToObject() { return nullptr; }
virtual Enumerator* ToEnumerator() { return nullptr; }
// An identifer can be:
// object, sturct/union/enum tag, typedef name, function, label.
Identifier* ToTypeName() {
// A typename has no linkage
// And a function has external or internal linkage
if (ToObject() || ToEnumerator() || linkage_ != L_NONE)
return nullptr;
return this;
}
virtual const std::string Name() const { return tok_->str_; }
enum Linkage Linkage() const { return linkage_; }
void SetLinkage(enum Linkage linkage) { linkage_ = linkage; }
virtual void TypeChecking() {}
protected:
Identifier(const Token* tok, QualType type, enum Linkage linkage, const AttrList& attrList={})
: Expr(tok, type), linkage_(linkage), attrList_(attrList) {}
// An identifier has property linkage
enum Linkage linkage_;
AttrList attrList_;
};
class Enumerator: public Identifier {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Enumerator* New(const Token* tok, int val);
virtual ~Enumerator() {}
virtual void Accept(Visitor* v);
virtual Enumerator* ToEnumerator() { return this; }
int Val() const { return cons_->IVal(); }
protected:
Enumerator(const Token* tok, int val)
: Identifier(tok, ArithmType::New(T_INT), L_NONE),
cons_(Constant::New(tok, T_INT, (long)val)) {}
Constant* cons_;
};
class Object : public Identifier {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static Object* New(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={});
static Object* NewAnony(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={});
~Object() {}
virtual void Accept(Visitor* v);
virtual Object* ToObject() { return this; }
virtual bool IsLVal() {
// TODO(wgtdkp): not all object is lval?
return true;
}
bool IsStatic() const {
return (Storage() & S_STATIC) || (Linkage() != L_NONE);
}
int Storage() const { return storage_; }
void SetStorage(int storage) { storage_ = storage; }
int Align() const { return align_; }
void SetAlign(int align) {
assert(align > 0);
// Allowing reduce alignment to implement __attribute__((packed))
//if (align < align_)
// Error(this, "alignment specifier cannot reduce alignment");
align_ = align;
}
int Offset() const { return offset_; }
void SetOffset(int offset) { offset_ = offset; }
Declaration* Decl() { return decl_; }
void SetDecl(Declaration* decl) { decl_ = decl; }
const AttrList& GetAttrList() const { return attrList_; }
unsigned char BitFieldBegin() const { return bitFieldBegin_; }
unsigned char BitFieldEnd() const { return bitFieldBegin_ + bitFieldWidth_; }
unsigned char BitFieldWidth() const { return bitFieldWidth_; }
static unsigned long BitFieldMask(Object* bitField) {
return BitFieldMask(bitField->bitFieldBegin_, bitField->bitFieldWidth_);
}
static unsigned long BitFieldMask(unsigned char begin, unsigned char width) {
auto end = begin + width;
return ((0xFFFFFFFFFFFFFFFFUL << (64 - end)) >> (64 - width)) << begin;
}
bool HasInit() const { return decl_ && decl_->Inits().size(); }
bool Anonymous() const { return anonymous_; }
virtual const std::string Name() const { return Identifier::Name(); }
std::string Repr() const {
assert(IsStatic() || anonymous_);
if (anonymous_)
return "anonymous." + std::to_string(id_);
if (linkage_ == L_NONE)
return Name() + "." + std::to_string(id_);
return Name();
}
protected:
Object(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={})
: Identifier(tok, type, linkage),
storage_(storage),
offset_(0),
align_(type->Align()),
decl_(nullptr),
bitFieldBegin_(bitFieldBegin),
bitFieldWidth_(bitFieldWidth),
anonymous_(false),
attrList_(attrList){}
private:
int storage_;
int offset_;
int align_;
Declaration* decl_;
unsigned char bitFieldBegin_;
// 0 means it's not a bitfield
unsigned char bitFieldWidth_;
bool anonymous_;
long id_ {0};
AttrList attrList_;
};
/*
* Declaration
*/
class FuncDef : public ExtDecl {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
using ParamList = std::vector<Object*>;
public:
static FuncDef* New(Identifier* ident, LabelStmt* retLabel);
virtual ~FuncDef() {}
virtual void Accept(Visitor* v);
::FuncType* FuncType() { return ident_->Type()->ToFunc(); }
CompoundStmt* Body() { return body_; }
void SetBody(CompoundStmt* body) { body_ = body; }
std::string Name() const { return ident_->Name(); }
enum Linkage Linkage() { return ident_->Linkage(); }
protected:
FuncDef(Identifier* ident, LabelStmt* retLabel)
: ident_(ident), retLabel_(retLabel) {}
private:
Identifier* ident_;
LabelStmt* retLabel_;
CompoundStmt* body_;
};
using ExtDeclList = std::list<ExtDecl*>;
class TranslationUnit : public ASTNode {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static TranslationUnit* New() { return new TranslationUnit();}
virtual ~TranslationUnit() {}
virtual void Accept(Visitor* v);
void Add(ExtDecl* extDecl) { extDecls_.push_back(extDecl); }
ExtDeclList& ExtDecls() { return extDecls_; }
const ExtDeclList& ExtDecls() const { return extDecls_; }
private:
TranslationUnit() {}
ExtDeclList extDecls_;
};
#endif

View File

@@ -0,0 +1,162 @@
#pragma once
#ifndef _WGTCC_CODE_GEN_H_
#define _WGTCC_CODE_GEN_H_
#include "ast.h"
#include "visitor.h"
#include <stack>
namespace triton{
namespace ir{
class value;
class module;
class type;
class context;
class builder;
class attribute;
}
}
using namespace triton;
class Parser;
struct Addr;
template<> class Evaluator<Addr>;
struct StaticInitializer;
class LValAssigner;
using TypeList = std::vector<Type*>;
using LocationList = std::vector<std::string>;
using StaticInitList = std::vector<StaticInitializer>;
// Error
inline void should_not_happen() { throw std::runtime_error("should not happen"); }
inline void error_not_implemented() { throw std::runtime_error("not implemented"); }
class Generator: public Visitor {
friend class Evaluator<Addr>;
friend class LValAssigner;
protected:
struct scope {
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};
void set_ret(ir::value* value);
ir::value *GenUnaryMinus(ir::value* arg);
public:
Generator(Parser* parser) : parser_(parser) {}
void Visit(ASTNode* node) { node->Accept(this); }
void VisitExpr(Expr* expr) { expr->Accept(this); }
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitTransOp(TransOp* transOp);
void VisitConditionalOp(ConditionalOp* condOp);
void VisitFuncCall(FuncCall* funcCall);
void VisitObject(Object* obj);
void VisitEnumerator(Enumerator* enumer);
void VisitIdentifier(Identifier* ident);
void VisitConstant(Constant* cons);
void VisitTempVar(TempVar* tempVar);
// Statement
void VisitDeclaration(Declaration* init);
void VisitEmptyStmt(EmptyStmt* emptyStmt);
void VisitIfStmt(IfStmt* ifStmt);
void VisitForStmt(ForStmt* ifStmt);
void VisitJumpStmt(JumpStmt* jumpStmt);
void VisitReturnStmt(ReturnStmt* returnStmt);
void VisitLabelStmt(LabelStmt* labelStmt);
void VisitCompoundStmt(CompoundStmt* compoundStmt);
void VisitFuncDef(FuncDef* funcDef);
void VisitTranslationUnit(TranslationUnit* unit);
void Gen(ir::module *mod);
protected:
// Triton-IR attributes
ir::attribute GenIRAttr(ASTNode::Attr attr);
// Triton-IR values
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty);
ir::value* GenCastOp(ir::value* op, ir::type* type);
// Triton-IR types
static ir::type* GenIRType(::Type* type, ir::context &ctx);
static ir::type* GenIRArithmType(ArithmType* type, ir::context& ctx);
static ir::type* GenIRArrayType(ArrayType* type, ir::context& ctx);
static ir::type* GenIRTileType(TileType* type, ir::context& ctx);
static ir::type* GenIRFuncType(FuncType* type, ir::context& ctx);
static ir::type* GenIRPointerType(PointerType* type, ir::context& ctx);
static ir::type* GenIRStructType(StructType* type, ir::context& ctx);
void AllocObjects(Scope* scope, const FuncDef::ParamList& params=FuncDef::ParamList());
// SSA
void pushScope();
void popScope();
private:
Parser* parser_;
ir::value* ret_;
ir::builder* bld_;
ir::context* ctx_;
ir::module* mod_;
private:
// std::stack<scope> scopes_;
LValAssigner* assign_;
};
class LValAssigner: public Visitor {
public:
LValAssigner(Generator* gen): gen_(gen) {}
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitObject(Object* obj);
void VisitIdentifier(Identifier* ident);
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
void VisitFuncCall(FuncCall*) { should_not_happen(); }
void VisitTransOp(TransOp*) { should_not_happen(); }
void VisitEnumerator(Enumerator*) { should_not_happen(); }
void VisitConstant(Constant*) { should_not_happen(); }
void VisitTempVar(TempVar*) { should_not_happen(); }
void VisitDeclaration(Declaration*) { should_not_happen(); }
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
void VisitIfStmt(IfStmt*) { should_not_happen(); }
void VisitForStmt(ForStmt*) { should_not_happen(); }
void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
void VisitLabelStmt(LabelStmt*) { should_not_happen(); }
void VisitCompoundStmt(CompoundStmt*) { should_not_happen(); }
void VisitFuncDef(FuncDef*) { should_not_happen(); }
void VisitTranslationUnit(TranslationUnit*) { should_not_happen(); }
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
rhs_ = rhs;
expr->Accept(this);
return ret_;
}
private:
ir::value* ret_;
ir::value* rhs_;
Generator* gen_;
};
#endif

164
include/triton/lang/cpp.h Normal file
View File

@@ -0,0 +1,164 @@
#pragma once
#ifndef _WGTCC_CPP_H_
#define _WGTCC_CPP_H_
#include "scanner.h"
#include <cstdio>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
class Macro;
struct CondDirective;
using MacroMap = std::map<std::string, Macro>;
using ParamList = std::list<std::string>;
using ParamMap = std::map<std::string, TokenSequence>;
using PPCondStack = std::stack<CondDirective>;
using PathList = std::list<std::string>;
class Macro {
public:
Macro(const TokenSequence& repSeq, bool preDef=false)
: funcLike_(false), variadic_(false),
preDef_(preDef), repSeq_(repSeq) {}
Macro(bool variadic, ParamList& params,
TokenSequence& repSeq, bool preDef=false)
: funcLike_(true), variadic_(variadic), preDef_(preDef),
params_(params), repSeq_(repSeq) {}
~Macro() {}
bool FuncLike() { return funcLike_; }
bool ObjLike() { return !FuncLike(); }
bool Variadic() { return variadic_; }
bool PreDef() { return preDef_; }
ParamList& Params() { return params_; }
TokenSequence RepSeq(const std::string* filename, unsigned line);
private:
bool funcLike_;
bool variadic_;
bool preDef_;
ParamList params_;
TokenSequence repSeq_;
};
struct CondDirective {
int tag_;
bool enabled_;
bool cond_;
};
class Preprocessor {
public:
Preprocessor(const std::string* str, bool isSrc = true)
: curLine_(1), lineLine_(0), curCond_(true), fName_(nullptr), fSrc_(nullptr) {
if(isSrc)
fSrc_ = str;
else
fName_ = str;
// Add predefined
Init();
}
~Preprocessor() {}
void Finalize(TokenSequence os);
void Process(TokenSequence& os);
void Expand(TokenSequence& os, TokenSequence is, bool inCond=false);
void Subst(TokenSequence& os, TokenSequence is,
bool leadingWS, const HideSet& hs, ParamMap& params);
void Glue(TokenSequence& os, TokenSequence is);
void Glue(TokenSequence& os, const Token* tok);
const Token* Stringize(TokenSequence is);
void Stringize(std::string& str, TokenSequence is);
const Token* ParseActualParam(TokenSequence& is, Macro* macro, ParamMap& paramMap);
int GetDirective(TokenSequence& is);
const Token* EvalDefOp(TokenSequence& is);
void ReplaceIdent(TokenSequence& is);
void ParseDirective(TokenSequence& os, TokenSequence& is, int directive);
void ParseIf(TokenSequence ls);
void ParseIfdef(TokenSequence ls);
void ParseIfndef(TokenSequence ls);
void ParseElif(TokenSequence ls);
void ParseElse(TokenSequence ls);
void ParseEndif(TokenSequence ls);
void ParseInclude(TokenSequence& is, TokenSequence ls);
void ParseDef(TokenSequence ls);
void ParseUndef(TokenSequence ls);
void ParseLine(TokenSequence ls);
void ParseError(TokenSequence ls);
void ParsePragma(TokenSequence ls);
void IncludeSrc(TokenSequence& is, const std::string* text, const std::string* filename);
void IncludeFile(TokenSequence& is, const std::string* filename);
bool ParseIdentList(ParamList& params, TokenSequence& is);
Macro* FindMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return nullptr;
return &res->second;
}
void AddMacro(const std::string& name,
std::string* text, bool preDef=false);
void AddMacro(const std::string& name, const Macro& macro) {
auto res = macroMap_.find(name);
if (res != macroMap_.end()) {
// TODO(wgtdkp): give warning
macroMap_.erase(res);
}
macroMap_.insert(std::make_pair(name, macro));
}
void RemoveMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return;
if(res->second.PreDef()) // Cannot undef predefined macro
return;
macroMap_.erase(res);
}
std::string* SearchFile(const std::string& name,
const bool libHeader,
bool next,
const std::string& curPath);
void AddSearchPath(std::string path);
void HandleTheFileMacro(TokenSequence& os, const Token* macro);
void HandleTheLineMacro(TokenSequence& os, const Token* macro);
void UpdateFirstTokenLine(TokenSequence ts);
bool NeedExpand() const {
if (ppCondStack_.empty())
return true;
auto top = ppCondStack_.top();
return top.enabled_ && top.cond_;
}
private:
void Init();
PPCondStack ppCondStack_;
unsigned curLine_;
unsigned lineLine_;
bool curCond_;
MacroMap macroMap_;
PathList searchPaths_;
const std::string* fName_;
const std::string* fSrc_;
};
#endif

View File

@@ -0,0 +1,22 @@
#pragma once
#ifndef _WGTCC_ENCODING_H_
#define _WGTCC_ENCODING_H_
#include <string>
enum class Encoding {
NONE,
CHAR16,
CHAR32,
UTF8,
WCHAR
};
void ConvertToUTF16(std::string& str);
void ConvertToUTF32(std::string& str);
void AppendUCN(std::string& str, int c);
#endif

View File

@@ -0,0 +1,17 @@
#pragma once
#ifndef _WGTCC_ERROR_H_
#define _WGTCC_ERROR_H_
struct SourceLocation;
class Token;
class Expr;
[[noreturn]] void Error(const char* format, ...);
[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...);
[[noreturn]] void Error(const Token* tok, const char* format, ...);
[[noreturn]] void Error(const Expr* expr, const char* format, ...);
#endif

View File

@@ -0,0 +1,130 @@
#pragma once
#ifndef _WGTCC_EVALUATOR_H_
#define _WGTCC_EVALUATOR_H_
#include "ast.h"
#include "error.h"
#include "visitor.h"
class Expr;
template<typename T>
class Evaluator: public Visitor {
public:
Evaluator() {}
virtual ~Evaluator() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
val_ = static_cast<T>(enumer->Val());
}
virtual void VisitIdentifier(Identifier* ident) {
Error(ident, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitObject(Object* obj) {
Error(obj, "expect constant expression");
}
virtual void VisitConstant(Constant* cons) {
if (cons->Type()->IsFloat()) {
val_ = static_cast<T>(cons->FVal());
} else if (cons->Type()->IsInteger()) {
val_ = static_cast<T>(cons->IVal());
} else {
assert(false);
}
}
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
T Eval(Expr* expr) {
expr->Accept(this);
return val_;
}
private:
T val_;
};
struct Addr {
std::string label_;
int offset_;
};
template<>
class Evaluator<Addr>: public Visitor {
public:
Evaluator<Addr>() {}
virtual ~Evaluator<Addr>() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
addr_.offset_ = enumer->Val();
}
virtual void VisitIdentifier(Identifier* ident) {
addr_.label_ = ident->Name();
addr_.offset_ = 0;
}
virtual void VisitObject(Object* obj) {
if (!obj->IsStatic()) {
Error(obj, "expect static object");
}
addr_.label_ = obj->Repr();
addr_.offset_ = 0;
}
virtual void VisitConstant(Constant* cons);
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
Addr Eval(Expr* expr) {
expr->Accept(this);
return addr_;
}
private:
Addr addr_;
};
#endif

View File

@@ -0,0 +1,103 @@
#pragma once
#ifndef _WGTCC_MEM_POOL_H_
#define _WGTCC_MEM_POOL_H_
#include <cstddef>
#include <vector>
class MemPool {
public:
MemPool(): allocated_(0) {}
virtual ~MemPool() {}
MemPool(const MemPool& other) = delete;
MemPool& operator=(const MemPool& other) = delete;
virtual void* Alloc() = 0;
virtual void Free(void* addr) = 0;
virtual void Clear() = 0;
protected:
size_t allocated_;
};
template <class T>
class MemPoolImp: public MemPool {
public:
MemPoolImp() : root_(nullptr) {}
virtual ~MemPoolImp() {}
MemPoolImp(const MemPool& other) = delete;
MemPoolImp& operator=(MemPool& other) = delete;
virtual void* Alloc();
virtual void Free(void* addr);
virtual void Clear();
private:
enum {
COUNT = (4 * 1024) / sizeof(T)
};
union Chunk {
Chunk* next_;
char mem_[sizeof(T)];
};
struct Block {
Block() {
for (size_t i = 0; i < COUNT - 1; ++i)
chunks_[i].next_ = &chunks_[i+1];
chunks_[COUNT-1].next_ = nullptr;
}
Chunk chunks_[COUNT];
};
std::vector<Block*> blocks_;
Chunk* root_;
};
template <class T>
void* MemPoolImp<T>::Alloc() {
if (nullptr == root_) { // 空间不够,需要分配空间
auto block = new Block();
root_ = block->chunks_;
// 如果blocks实现为std::list, 那么push_back实际的overhead更大
// 这也表明,即使我们不需要随机访问功能(那么std::vector的拷贝是一种overhead)
// 仍然倾向于使用std::vector
// 当然std::vector的指数级capacity增长会造成内存浪费。
blocks_.push_back(block);
}
auto ret = root_;
root_ = root_->next_;
++allocated_;
return ret;
}
template <class T>
void MemPoolImp<T>::Free(void* addr) {
if (nullptr == addr)
return;
auto chunk = static_cast<Chunk*>(addr);
chunk->next_ = root_;
root_ = chunk;
--allocated_;
}
template <class T>
void MemPoolImp<T>::Clear() {
for (auto block: blocks_)
delete block;
blocks_.resize(0);
root_ = nullptr;
allocated_ = 0;
}
#endif

View File

@@ -0,0 +1,259 @@
#pragma once
#ifndef _PARSER_H_
#define _PARSER_H_
#include "ast.h"
#include "encoding.h"
#include "error.h"
#include "mem_pool.h"
#include "scope.h"
#include "token.h"
#include <cassert>
#include <memory>
#include <stack>
class Preprocessor;
struct DeclInfo {
DeclInfo(const Token* _tok,
QualType _type,
ASTNode::AttrList _attrs = {})
: tok(_tok), type(_type), attrs(_attrs) {}
const Token* tok;
QualType type;
ASTNode::AttrList attrs;
};
class Parser {
using LiteralList = std::vector<Constant*>;
using StaticObjectList = std::vector<Object*>;
using CaseLabelList = std::vector<std::pair<Constant*, LabelStmt*>>;
using LabelJumpList = std::list<std::pair<const Token*, JumpStmt*>>;
using LabelMap = std::map<std::string, LabelStmt*>;
friend class Generator;
public:
explicit Parser(TokenSequence& ts)
: unit_(TranslationUnit::New()),
ts_(ts),
externalSymbols_(new Scope(nullptr, S_BLOCK)),
errTok_(nullptr),
curScope_(new Scope(nullptr, S_FILE)),
curFunc_(nullptr),
breakDest_(nullptr),
continueDest_(nullptr),
caseLabels_(nullptr),
defaultLabel_(nullptr) {
ts_.SetParser(this);
}
~Parser() {}
Constant* ParseConstant(const Token* tok);
Constant* ParseFloat(const Token* tok);
Constant* ParseInteger(const Token* tok);
Constant* ParseCharacter(const Token* tok);
Encoding ParseLiteral(std::string& str, const Token* tok);
Constant* ConcatLiterals(const Token* tok);
Expr* ParseGeneric();
void Parse();
void ParseTranslationUnit();
FuncDef* ParseFuncDef(Identifier* ident);
// Expressions
Expr* ParseExpr();
Expr* ParsePrimaryExpr();
QualType TryCompoundLiteral();
Object* ParseCompoundLiteral(QualType type);
Expr* ParsePostfixExpr();
Expr* ParsePostfixExprTail(Expr* primExpr);
Expr* ParseSubScripting(Expr* pointer);
BinaryOp* ParseMemberRef(const Token* tok, int op, Expr* lhs);
UnaryOp* ParsePostfixIncDec(const Token* tok, Expr* operand);
FuncCall* ParseFuncCall(Expr* caller);
Expr* ParseUnaryExpr();
Constant* ParseSizeof();
Constant* ParseAlignof();
UnaryOp* ParsePrefixIncDec(const Token* tok);
UnaryOp* ParseUnaryOp(const Token* tok, int op);
Expr* ParseDerefOp(const Token* tok);
QualType ParseTypeName();
Expr* ParseCastExpr();
Expr* ParseRangeExpr();
Expr* ParseMatmulExpr();
Expr* ParseMultiplicativeExpr();
Expr* ParseAdditiveExpr();
Expr* ParseShiftExpr();
Expr* ParseRelationalExpr();
Expr* ParseEqualityExpr();
Expr* ParseBitiwiseAndExpr();
Expr* ParseBitwiseXorExpr();
Expr* ParseBitwiseOrExpr();
Expr* ParseLogicalAndExpr();
Expr* ParseLogicalOrExpr();
Expr* ParseConditionalExpr();
Expr* ParseCommaExpr();
Expr* ParseAssignExpr();
// Declarations
CompoundStmt* ParseDecl();
void ParseStaticAssert();
QualType ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec);
QualType ParseSpecQual();
int ParseAlignas();
Type* ParseStructUnionSpec(bool isStruct);
StructType* ParseStructUnionDecl(StructType* type);
void ParseBitField(StructType* structType, const Token* tok, QualType type);
Type* ParseEnumSpec();
Type* ParseEnumerator(ArithmType* type);
int ParseQual();
QualType ParsePointer(QualType typePointedTo);
DeclInfo ParseDeclarator(QualType type);
QualType ParseArrayFuncDeclarator(const Token* ident, QualType base);
int ParseArrayLength();
TileType::ShapeInt ParseTileShape();
bool ParseParamList(FuncType::ParamList& params);
Object* ParseParamDecl();
QualType ParseAbstractDeclarator(QualType type);
Identifier* ParseDirectDeclarator(QualType type,
int storageSpec,
int funcSpec,
int align);
// Initializer
void ParseInitializer(Declaration* decl,
QualType type,
int offset,
bool designated=false,
bool forceBrace=false,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0);
void ParseArrayInitializer(Declaration* decl,
ArrayType* type,
int offset,
bool designated);
StructType::Iterator ParseStructDesignator(StructType* type,
const std::string& name);
void ParseStructInitializer(Declaration* decl,
StructType* type,
int offset,
bool designated);
bool ParseLiteralInitializer(Declaration* init,
ArrayType* type,
int offset);
Declaration* ParseInitDeclarator(Identifier* ident);
Declaration* ParseInitDeclaratorSub(Object* obj);
// Statements
Stmt* ParseStmt();
CompoundStmt* ParseCompoundStmt(FuncType* funcType=nullptr);
IfStmt* ParseIfStmt();
CompoundStmt* ParseSwitchStmt();
CompoundStmt* ParseWhileStmt();
CompoundStmt* ParseDoStmt();
ForStmt *ParseForStmt();
JumpStmt* ParseGotoStmt();
JumpStmt* ParseContinueStmt();
JumpStmt* ParseBreakStmt();
ReturnStmt* ParseReturnStmt();
CompoundStmt* ParseLabelStmt(const Token* label);
CompoundStmt* ParseCaseStmt();
CompoundStmt* ParseDefaultStmt();
Identifier* ProcessDeclarator(const Token* tok,
QualType type, const ASTNode::AttrList &attrs,
int storageSpec,
int funcSpec,
int align);
// GNU extensions
ASTNode::AttrList TryAttributeSpecList();
void ParseAttributeSpec(ASTNode::AttrList &attrList);
ASTNode::Attr ParseAttribute();
bool IsTypeName(const Token* tok) const{
if (tok->IsTypeSpecQual())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
if (ident && ident->ToTypeName())
return true;
}
return false;
}
bool IsType(const Token* tok) const{
if (tok->IsDecl())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
return (ident && ident->ToTypeName());
}
return false;
}
void EnsureInteger(Expr* expr) {
if (!expr->Type()->IsInteger()) {
Error(expr, "expect integer expression");
}
}
void EnterBlock(FuncType* funcType=nullptr);
void ExitBlock() { curScope_ = curScope_->Parent(); }
void EnterProto() { curScope_ = new Scope(curScope_, S_PROTO); }
void ExitProto() { curScope_ = curScope_->Parent(); }
FuncDef* EnterFunc(Identifier* ident);
void ExitFunc();
LabelStmt* FindLabel(const std::string& label) {
auto ret = curLabels_.find(label);
if (curLabels_.end() == ret)
return nullptr;
return ret->second;
}
void AddLabel(const std::string& label, LabelStmt* labelStmt) {
assert(nullptr == FindLabel(label));
curLabels_[label] = labelStmt;
}
TranslationUnit* Unit() { return unit_; }
FuncDef* CurFunc() { return curFunc_; }
const TokenSequence& ts() const { return ts_; }
private:
static bool IsBuiltin(FuncType* type);
static bool IsBuiltin(const std::string& name);
static Identifier* GetBuiltin(const Token* tok);
static void DefineBuiltins();
static FuncType* vaStartType_;
static FuncType* vaArgType_;
// The root of the AST
TranslationUnit* unit_;
TokenSequence& ts_;
// It is not the real scope,
// It contains all external symbols(resolved and not resolved)
Scope* externalSymbols_;
const Token* errTok_;
Scope* curScope_;
FuncDef* curFunc_;
LabelMap curLabels_;
LabelJumpList unresolvedJumps_;
LabelStmt* breakDest_;
LabelStmt* continueDest_;
CaseLabelList* caseLabels_;
LabelStmt* defaultLabel_;
};
#endif

View File

@@ -0,0 +1,86 @@
#pragma once
#ifndef _WGTCC_SCANNER_H_
#define _WGTCC_SCANNER_H_
#include "error.h"
#include "encoding.h"
#include "token.h"
#include <string>
#include <cassert>
class Scanner {
public:
explicit Scanner(const Token* tok)
: Scanner(&tok->str_, tok->loc_) {}
Scanner(const std::string* text, const SourceLocation& loc)
: Scanner(text, loc.filename_, loc.line_, loc.column_) {}
explicit Scanner(const std::string* text,
const std::string* filename=nullptr,
unsigned line=1, unsigned column=1)
: text_(text), tok_(Token::END) {
// TODO(wgtdkp): initialization
p_ = &(*text_)[0];
loc_ = {filename, p_, line, 1};
}
virtual ~Scanner() {}
Scanner(const Scanner& other) = delete;
Scanner& operator=(const Scanner& other) = delete;
// Scan plain text and generate tokens in ts.
// The param 'ts' need not be empty, if so, the tokens
// are inserted at the *header* of 'ts'.
// The param 'ws' tells if there is leading white space
// before this token, it is only SkipComment() that will
// set this param.
Token* Scan(bool ws=false);
void Tokenize(TokenSequence& ts);
static std::string ScanHeadName(const Token* lhs, const Token* rhs);
Encoding ScanCharacter(int& val);
Encoding ScanLiteral(std::string& val);
std::string ScanIdentifier();
private:
Token* SkipIdentifier();
Token* SkipNumber();
Token* SkipLiteral();
Token* SkipCharacter();
Token* MakeToken(int tag);
Token* MakeNewLine();
Encoding ScanEncoding(int c);
int ScanEscaped();
int ScanHexEscaped();
int ScanOctEscaped(int c);
int ScanUCN(int len);
void SkipWhiteSpace();
void SkipComment();
bool IsUCN(int c) { return c == '\\' && (Test('u') || Test('U')); }
bool IsOctal(int c) { return '0' <= c && c <= '7'; }
int XDigit(int c);
bool Empty() const { return *p_ == 0; }
int Peek();
bool Test(int c) { return Peek() == c; };
int Next();
void PutBack();
bool Try(int c) {
if (Peek() == c) {
Next();
return true;
}
return false;
};
void Mark() { tok_.loc_ = loc_; };
const std::string* text_;
SourceLocation loc_;
Token tok_;
const char* p_;
};
std::string* ReadFile(const std::string& filename);
#endif

View File

@@ -0,0 +1,72 @@
#pragma once
#ifndef _WGTCC_SCOPE_H_
#define _WGTCC_SCOPE_H_
#include <iostream>
#include <map>
#include <string>
#include <vector>
class Identifier;
class Token;
enum ScopeType {
S_FILE,
S_PROTO,
S_BLOCK,
S_FUNC,
};
class Scope {
friend class StructType;
using TagList = std::vector<Identifier*>;
using IdentMap = std::map<std::string, Identifier*>;
public:
explicit Scope(Scope* parent, enum ScopeType type)
: parent_(parent), type_(type) {}
~Scope() {}
Scope* Parent() { return parent_; }
void SetParent(Scope* parent) { parent_ = parent; }
enum ScopeType Type() const { return type_; }
Identifier* Find(const Token* tok);
Identifier* FindInCurScope(const Token* tok);
Identifier* FindTag(const Token* tok);
Identifier* FindTagInCurScope(const Token* tok);
TagList AllTagsInCurScope() const;
void Insert(Identifier* ident);
void Insert(const std::string& name, Identifier* ident);
void InsertTag(Identifier* ident);
void Print();
bool operator==(const Scope& other) const { return type_ == other.type_; }
IdentMap::iterator begin() { return identMap_.begin(); }
IdentMap::iterator end() { return identMap_.end(); }
size_t size() const { return identMap_.size(); }
private:
Identifier* Find(const std::string& name);
Identifier* FindInCurScope(const std::string& name);
Identifier* FindTag(const std::string& name);
Identifier* FindTagInCurScope(const std::string& name);
std::string TagName(const std::string& name) {
return name + "@:tag";
}
static bool IsTagName(const std::string& name) {
return name.size() > 5 && name[name.size() - 5] == '@';
}
const Scope& operator=(const Scope& other);
Scope(const Scope& scope);
Scope* parent_;
enum ScopeType type_;
IdentMap identMap_;
};
#endif

429
include/triton/lang/token.h Normal file
View File

@@ -0,0 +1,429 @@
#pragma once
#ifndef _WGTCC_TOKEN_H_
#define _WGTCC_TOKEN_H_
#include "error.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <list>
#include <set>
#include <string>
#include <unordered_map>
class Generator;
class Parser;
class Scanner;
class Token;
class TokenSequence;
using HideSet = std::set<std::string>;
using TokenList = std::list<const Token*>;
struct SourceLocation {
const std::string* filename_;
const char* lineBegin_;
unsigned line_;
unsigned column_;
const char* Begin() const {
return lineBegin_ + column_ - 1;
}
};
class Token {
friend class Scanner;
public:
enum {
// Punctuators
LPAR = '(',
RPAR = ')',
LSQB = '[',
RSQB = ']',
COLON = ':',
COMMA = ',',
SEMI = ';',
ADD = '+',
SUB = '-',
MUL = '*',
DIV = '/',
OR = '|',
AND = '&',
XOR = '^',
LESS = '<',
GREATER = '>',
EQUAL = '=',
DOT = '.',
MOD = '%',
LBRACE = '{',
RBRACE = '}',
TILDE = '~',
NOT = '!',
COND = '?',
SHARP = '#',
MATMUL = '@',
NEW_LINE = '\n',
DSHARP = 128, // '##'
PTR,
INC,
DEC,
LEFT,
RIGHT,
LE,
GE,
EQ,
NE,
LOGICAL_AND,
LOGICAL_OR,
MUL_ASSIGN,
DIV_ASSIGN,
MOD_ASSIGN,
ADD_ASSIGN,
SUB_ASSIGN,
LEFT_ASSIGN,
RIGHT_ASSIGN,
AND_ASSIGN,
XOR_ASSIGN,
OR_ASSIGN,
ELLIPSIS,
MASKED_DEREF,
// Punctuators end
// KEYWORD BEGIN
// TYPE QUALIFIER BEGIN
CONST,
RESTRICT,
VOLATILE,
ATOMIC,
// TYPE QUALIFIER END
// TYPE SPECIFIER BEGIN
VOID,
CHAR,
SHORT,
INT,
LONG,
HALF,
FLOAT,
DOUBLE,
SIGNED,
UNSIGNED,
BOOL, // _Bool
COMPLEX, // _Complex
STRUCT,
UNION,
ENUM,
// TYPE SPECIFIER END
ATTRIBUTE, // GNU extension __attribute__
// FUNCTION SPECIFIER BEGIN
INLINE,
NORETURN, // _Noreturn
// FUNCTION SPECIFIER END
// TILE ARITHMETICS BEGIN
NEWAXIS,
MAX,
MIN,
// TILE ARITHMETICS END
ALIGNAS, // _Alignas
// For syntactic convenience
STATIC_ASSERT, // _Static_assert
// STORAGE CLASS SPECIFIER BEGIN
TYPEDEF,
EXTERN,
STATIC,
THREAD, // _Thread_local
AUTO,
GLOBAL,
CMEM, // constant memory
// STORAGE CLASS SPECIFIER END
BREAK,
CASE,
CONTINUE,
DEFAULT,
DO,
ELSE,
FOR,
GOTO,
IF,
RETURN,
SIZEOF,
SWITCH,
WHILE,
ALIGNOF, // _Alignof
GENERIC, // _Generic
IMAGINARY, // _Imaginary
// KEYWORD END
IDENTIFIER,
CONSTANT,
I_CONSTANT,
C_CONSTANT,
F_CONSTANT,
LITERAL,
// For the parser, a identifier is a typedef name or user defined type
POSTFIX_INC,
POSTFIX_DEC,
PREFIX_INC,
PREFIX_DEC,
ADDR, // '&'
DEREF, // '*'
PLUS,
MINUS,
CAST,
REDUCE,
// For preprocessor
PP_IF,
PP_IFDEF,
PP_IFNDEF,
PP_ELIF,
PP_ELSE,
PP_ENDIF,
PP_INCLUDE,
PP_DEFINE,
PP_UNDEF,
PP_LINE,
PP_ERROR,
PP_PRAGMA,
PP_NONE,
PP_EMPTY,
IGNORE,
INVALID,
END,
NOTOK = -1,
};
static Token* New(int tag);
static Token* New(const Token& other);
static Token* New(int tag,
const SourceLocation& loc,
const std::string& str,
bool ws=false);
Token& operator=(const Token& other) {
tag_ = other.tag_;
ws_ = other.ws_;
loc_ = other.loc_;
str_ = other.str_;
hs_ = other.hs_ ? new HideSet(*other.hs_): nullptr;
return *this;
}
virtual ~Token() {}
// Token::NOTOK represents not a kw.
static int KeyWordTag(const std::string& key) {
auto kwIter = kwTypeMap_.find(key);
if (kwTypeMap_.end() == kwIter)
return Token::NOTOK; // Not a key word type
return kwIter->second;
}
static bool IsKeyWord(const std::string& name);
static bool IsKeyWord(int tag) { return CONST <= tag && tag < IDENTIFIER; }
bool IsKeyWord() const { return IsKeyWord(tag_); }
bool IsPunctuator() const { return 0 <= tag_ && tag_ <= ELLIPSIS; }
bool IsLiteral() const { return tag_ == LITERAL; }
bool IsConstant() const { return CONSTANT <= tag_ && tag_ <= F_CONSTANT; }
bool IsIdentifier() const { return IDENTIFIER == tag_; }
bool IsEOF() const { return tag_ == Token::END; }
bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; }
bool IsDecl() const { return CONST <= tag_ && tag_ <= GLOBAL; }
static const char* Lexeme(int tag) {
auto iter = tagLexemeMap_.find(tag);
if (iter == tagLexemeMap_.end())
return nullptr;
return iter->second;
}
int tag_;
// 'ws_' standards for weither there is preceding white space
// This is to simplify the '#' operator(stringize) in macro expansion
bool ws_ { false };
SourceLocation loc_;
std::string str_;
HideSet* hs_ { nullptr };
private:
explicit Token(int tag): tag_(tag) {}
Token(int tag, const SourceLocation& loc,
const std::string& str, bool ws=false)
: tag_(tag), ws_(ws), loc_(loc), str_(str) {}
Token(const Token& other) {
*this = other;
}
static const std::unordered_map<std::string, int> kwTypeMap_;
static const std::unordered_map<int, const char*> tagLexemeMap_;
};
class TokenSequence {
friend class Preprocessor;
public:
TokenSequence(): tokList_(new TokenList()),
begin_(tokList_->begin()), end_(tokList_->end()) {}
explicit TokenSequence(Token* tok) {
TokenSequence();
InsertBack(tok);
}
explicit TokenSequence(TokenList* tokList)
: tokList_(tokList),
begin_(tokList->begin()),
end_(tokList->end()) {}
TokenSequence(TokenList* tokList,
TokenList::iterator begin,
TokenList::iterator end)
: tokList_(tokList), begin_(begin), end_(end) {}
~TokenSequence() {}
TokenSequence(const TokenSequence& other) { *this = other; }
const TokenSequence& operator=(const TokenSequence& other) {
tokList_ = other.tokList_;
begin_ = other.begin_;
end_ = other.end_;
return *this;
}
void Copy(const TokenSequence& other) {
tokList_ = new TokenList(other.begin_, other.end_);
begin_ = tokList_->begin();
end_ = tokList_->end();
for (auto iter = begin_; iter != end_; ++iter)
*iter = Token::New(**iter);
}
void UpdateHeadLocation(const SourceLocation& loc) {
assert(!Empty());
auto tok = const_cast<Token*>(Peek());
tok->loc_ = loc;
}
void FinalizeSubst(bool leadingWS, const HideSet& hs) {
auto ts = *this;
while (!ts.Empty()) {
auto tok = const_cast<Token*>(ts.Next());
if (!tok->hs_)
tok->hs_ = new HideSet(hs);
else
tok->hs_->insert(hs.begin(), hs.end());
}
// Even if the token sequence is empty
const_cast<Token*>(Peek())->ws_ = leadingWS;
}
const Token* Expect(int expect);
bool Try(int tag) {
if (Peek()->tag_ == tag) {
Next();
return true;
}
return false;
}
bool Test(int tag) { return Peek()->tag_ == tag; }
const Token* Next() {
auto ret = Peek();
if (!ret->IsEOF()) {
++begin_;
Peek(); // May skip newline token, but why ?
} else {
++exceed_end;
}
return ret;
}
void PutBack() {
assert(begin_ != tokList_->begin());
if (exceed_end > 0) {
--exceed_end;
} else {
--begin_;
if ((*begin_)->tag_ == Token::NEW_LINE)
PutBack();
}
}
const Token* Peek() const;
const Token* Peek2() {
if (Empty())
return Peek(); // Return the Token::END
Next();
auto ret = Peek();
PutBack();
return ret;
}
const Token* Back() const {
auto back = end_;
return *--back;
}
void PopBack() {
assert(!Empty());
assert(end_ == tokList_->end());
auto size_eq1 = tokList_->back() == *begin_;
tokList_->pop_back();
end_ = tokList_->end();
if (size_eq1)
begin_ = end_;
}
TokenList::iterator Mark() { return begin_; }
void ResetTo(TokenList::iterator mark) { begin_ = mark; }
bool Empty() const { return Peek()->tag_ == Token::END; }
void InsertBack(TokenSequence& ts) {
auto pos = tokList_->insert(end_, ts.begin_, ts.end_);
if (begin_ == end_) {
begin_ = pos;
}
}
void InsertBack(const Token* tok) {
auto pos = tokList_->insert(end_, tok);
if (begin_ == end_) {
begin_ = pos;
}
}
// If there is preceding newline
void InsertFront(TokenSequence& ts) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, ts.begin_, ts.end_);
}
void InsertFront(const Token* tok) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, tok);
}
bool IsBeginOfLine() const;
TokenSequence GetLine();
void SetParser(Parser* parser) { parser_ = parser; }
void Print(FILE* fp=stdout) const;
void Print(std::string *str) const;
private:
// Find a insert position with no preceding newline
TokenList::iterator GetInsertFrontPos() {
auto pos = begin_;
if (pos == tokList_->begin())
return pos;
--pos;
while (pos != tokList_->begin() && (*pos)->tag_ == Token::NEW_LINE)
--pos;
return ++pos;
}
TokenList* tokList_;
mutable TokenList::iterator begin_;
TokenList::iterator end_;
Parser* parser_ {nullptr};
int exceed_end {0};
};
#endif

464
include/triton/lang/type.h Normal file
View File

@@ -0,0 +1,464 @@
#pragma once
#ifndef _WGTCC_TYPE_H_
#define _WGTCC_TYPE_H_
#include "mem_pool.h"
#include "scope.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <list>
class Scope;
class Token;
class Expr;
class Type;
class QualType;
class VoidType;
class Identifier;
class Object;
class Constant;
class ArithmType;
class DerivedType;
class ArrayType;
class TileType;
class FuncType;
class PointerType;
class StructType;
class EnumType;
enum {
// Storage class specifiers
S_TYPEDEF = 0x01,
S_EXTERN = 0x02,
S_STATIC = 0x04,
S_THREAD = 0x08,
S_CONSTANT = 0x10,
S_GLOBAL = 0x20,
// Type specifier
T_SIGNED = 0x40,
T_UNSIGNED = 0x80,
T_CHAR = 0x100,
T_SHORT = 0x200,
T_INT = 0x400,
T_LONG = 0x800,
T_VOID = 0x1000,
T_HALF = 0x2000,
T_FLOAT = 0x4000,
T_DOUBLE = 0x8000,
T_BOOL = 0x10000,
T_COMPLEX = 0x20000,
// T_ATOMIC = 0x40000,
T_STRUCT_UNION = 0x80000,
T_ENUM = 0x100000,
T_TYPEDEF_NAME = 0x200000,
T_LLONG = 0x4000000,
// Function specifier
F_INLINE = 0x8000000,
F_NORETURN = 0x10000000,
};
struct Qualifier {
enum {
CONST = 0x01,
RESTRICT = 0x02,
VOLATILE = 0x04,
CMEM = 0x08,
MASK = CONST | RESTRICT | VOLATILE | CMEM
};
};
class QualType {
public:
QualType(Type* ptr, int quals=0x00)
: ptr_(reinterpret_cast<intptr_t>(ptr)) {
assert((quals & ~Qualifier::MASK) == 0);
ptr_ |= quals;
}
operator bool() const { return !IsNull(); }
bool IsNull() const { return GetPtr() == nullptr; }
const Type* GetPtr() const {
return reinterpret_cast<const Type*>(ptr_ & ~Qualifier::MASK);
}
Type* GetPtr() {
return reinterpret_cast<Type*>(ptr_ & ~Qualifier::MASK);
}
Type& operator*() { return *GetPtr(); }
const Type& operator*() const { return *GetPtr(); }
Type* operator->() { return GetPtr(); }
const Type* operator->() const { return GetPtr(); }
// Indicate whether the specified types are identical(exclude qualifiers).
friend bool operator==(QualType lhs, QualType rhs) {
return lhs.operator->() == rhs.operator->();
}
friend bool operator!=(QualType lhs, QualType rhs) {
return !(lhs == rhs);
}
int Qual() const { return ptr_ & 0x07; }
bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; }
private:
intptr_t ptr_;
};
class Type {
public:
static const int intWidth_ = 4;
static const int machineWidth_ = 8;
bool operator!=(const Type& other) const = delete;
bool operator==(const Type& other) const = delete;
virtual bool Compatible(const Type& other) const {
return complete_ == other.complete_;
}
virtual ~Type() {}
// For Debugging
virtual std::string Str() const = 0;
virtual int Width() const = 0;
virtual int Align() const { return Width(); }
static int MakeAlign(int offset, int align) {
if ((offset % align) == 0)
return offset;
if (offset >= 0)
return offset + align - (offset % align);
else
return offset - align - (offset % align);
}
static QualType MayCast(QualType type, bool inProtoScope=false);
bool Complete() const { return complete_; }
void SetComplete(bool complete) const { complete_ = complete; }
bool IsReal() const { return IsInteger() || IsFloat(); };
virtual bool IsScalar() const { return false; }
virtual bool IsFloat() const { return false; }
virtual bool IsInteger() const { return false; }
virtual bool IsBool() const { return false; }
virtual bool IsVoidPointer() const { return false; }
virtual bool IsUnsigned() const { return false; }
virtual bool IsTile() const { return ToTile() != nullptr; }
const Type* ScalarType() const;
Type* ScalarType();
virtual VoidType* ToVoid() { return nullptr; }
virtual const VoidType* ToVoid() const { return nullptr; }
virtual ArithmType* ToArithm() { return nullptr; }
virtual const ArithmType* ToArithm() const { return nullptr; }
virtual ArrayType* ToArray() { return nullptr; }
virtual const ArrayType* ToArray() const { return nullptr; }
virtual TileType* ToTile() { return nullptr; }
virtual const TileType* ToTile() const { return nullptr; }
virtual FuncType* ToFunc() { return nullptr; }
virtual const FuncType* ToFunc() const { return nullptr; }
virtual PointerType* ToPointer() { return nullptr; }
virtual const PointerType* ToPointer() const { return nullptr; }
virtual DerivedType* ToDerived() { return nullptr; }
virtual const DerivedType* ToDerived() const { return nullptr; }
virtual StructType* ToStruct() { return nullptr; }
virtual const StructType* ToStruct() const { return nullptr; }
protected:
Type(MemPool* pool, bool complete)
: complete_(complete), pool_(pool) {}
mutable bool complete_;
MemPool* pool_;
};
class VoidType : public Type {
public:
static VoidType* New();
virtual ~VoidType() {}
virtual VoidType* ToVoid() { return this; }
virtual const VoidType* ToVoid() const { return this; }
virtual bool Compatible(const Type& other) const { return other.ToVoid(); }
virtual int Width() const {
// Non-standard GNU extension
return 1;
}
virtual std::string Str() const { return "void:1"; }
protected:
explicit VoidType(MemPool* pool): Type(pool, false) {}
};
class ArithmType : public Type {
public:
static ArithmType* New(int typeSpec);
virtual ~ArithmType() {}
virtual ArithmType* ToArithm() { return this; }
virtual const ArithmType* ToArithm() const { return this; }
virtual bool Compatible(const Type& other) const {
// C11 6.2.7 [1]: Two types have compatible type if their types are the same
// But I would to loose this constraints: integer and pointer are compatible
// if (IsInteger() && other.ToPointer())
// return other.Compatible(*this);
return this == &other;
}
virtual int Width() const;
virtual std::string Str() const;
virtual bool IsScalar() const { return true; }
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
virtual bool IsFloat() const {
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
}
virtual bool IsBool() const { return tag_ & T_BOOL; }
bool IsComplex() const { return tag_ & T_COMPLEX; }
int Tag() const { return tag_; }
int Rank() const;
static ArithmType* IntegerPromote(ArithmType* type) {
assert(type->IsInteger());
if (type->Rank() < ArithmType::New(T_INT)->Rank())
return ArithmType::New(T_INT);
return type;
}
static ArithmType* MaxType(ArithmType* lhsType,
ArithmType* rhsType);
protected:
explicit ArithmType(MemPool* pool, int spec)
: Type(pool, true), tag_(Spec2Tag(spec)) {}
private:
static int Spec2Tag(int spec);
int tag_;
};
class DerivedType : public Type {
public:
QualType Derived() const { return derived_; }
void SetDerived(QualType derived) { derived_ = derived; }
virtual DerivedType* ToDerived() { return this; }
virtual const DerivedType* ToDerived() const { return this; }
protected:
DerivedType(MemPool* pool, QualType derived)
: Type(pool, true), derived_(derived) {}
QualType derived_;
};
class PointerType : public DerivedType {
public:
static PointerType* New(QualType derived);
virtual ~PointerType() {}
virtual PointerType* ToPointer() { return this; }
virtual const PointerType* ToPointer() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 8; }
virtual bool IsScalar() const { return true; }
virtual bool IsVoidPointer() const { return derived_->ToVoid(); }
virtual std::string Str() const {
return derived_->Str() + "*:" + std::to_string(Width());
}
protected:
PointerType(MemPool* pool, QualType derived): DerivedType(pool, derived) {}
};
class ArrayType : public DerivedType {
public:
static ArrayType* New(int len, QualType eleType);
static ArrayType* New(Expr* expr, QualType eleType);
virtual ~ArrayType() { /*delete derived_;*/ }
virtual ArrayType* ToArray() { return this; }
virtual const ArrayType* ToArray() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const {
return Complete() ? (derived_->Width() * len_): 0;
}
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[]:" + std::to_string(Width());
}
int GetElementOffset(int idx) const { return derived_->Width() * idx; }
int Len() const { return len_; }
void SetLen(int len) { len_ = len; }
bool Variadic() const { return lenExpr_ != nullptr; }
protected:
ArrayType(MemPool* pool, Expr* lenExpr, QualType derived)
: DerivedType(pool, derived),
lenExpr_(lenExpr), len_(0) {
SetComplete(false);
}
ArrayType(MemPool* pool, int len, QualType derived)
: DerivedType(pool, derived),
lenExpr_(nullptr), len_(len) {
SetComplete(len_ >= 0);
}
const Expr* lenExpr_;
int len_;
};
class TileType : public DerivedType {
public:
using ShapeExpr = std::vector<Expr*>;
using ShapeInt = std::vector<int>;
public:
static TileType* New(const ShapeExpr& expr, QualType eleType);
static TileType* New(const ShapeInt& shape, QualType eleType);
virtual ~TileType() { }
virtual TileType* ToTile() { return this; }
virtual const TileType* ToTile() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[{}]:" + std::to_string(Width());
}
ShapeInt Shape() { return shape_; }
int NumEle() const {
int ret = 1;
for(int s: shape_)
ret *= s;
return ret;
}
protected:
TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)
: DerivedType(pool, derived),
shapeExpr_(expr) {
bool isComplete = true;
for(Expr* s: shapeExpr_)
isComplete = isComplete && !s;
SetComplete(isComplete);
}
TileType(MemPool* pool, const ShapeInt& shape, QualType derived)
: DerivedType(pool, derived),
shape_(shape) {
bool isComplete = true;
for(int s: shape_)
isComplete = isComplete && (s>=0);
SetComplete(isComplete);
}
protected:
ShapeExpr shapeExpr_;
ShapeInt shape_;
};
class FuncType : public DerivedType {
public:
using ParamList = std::vector<Object*>;
public:
static FuncType* New(QualType derived,
int funcSpec,
bool variadic,
const ParamList& params);
virtual ~FuncType() {}
virtual FuncType* ToFunc() { return this; }
virtual const FuncType* ToFunc() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 1; }
virtual std::string Str() const;
const ParamList& Params() const { return params_; }
void SetParams(const ParamList& params) { params_ = params; }
bool Variadic() const { return variadic_; }
bool IsInline() const { return inlineNoReturn_ & F_INLINE; }
bool IsNoReturn() const { return inlineNoReturn_ & F_NORETURN; }
protected:
FuncType(MemPool* pool, QualType derived, int inlineReturn,
bool variadic, const ParamList& params)
: DerivedType(pool, derived), inlineNoReturn_(inlineReturn),
variadic_(variadic), params_(params) {
SetComplete(false);
}
private:
int inlineNoReturn_;
bool variadic_;
ParamList params_;
};
class StructType : public Type {
public:
using MemberList = std::list<Object*>;
using Iterator = std::list<Object*>::iterator;
public:
static StructType* New(bool isStruct,
bool hasTag,
Scope* parent);
virtual ~StructType() {}
virtual StructType* ToStruct() { return this; }
virtual const StructType* ToStruct() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return width_; }
virtual int Align() const { return align_; }
virtual std::string Str() const;
// struct/union
void AddMember(Object* member);
void AddBitField(Object* member, int offset);
bool IsStruct() const { return isStruct_; }
Object* GetMember(const std::string& member);
Scope* MemberMap() { return memberMap_; }
MemberList& Members() { return members_; }
int Offset() const { return offset_; }
bool HasTag() const { return hasTag_; }
void MergeAnony(Object* anony);
void Finalize();
protected:
// Default is incomplete
StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent);
StructType(const StructType& other);
private:
void CalcWidth();
bool isStruct_;
bool hasTag_;
Scope* memberMap_;
MemberList members_;
int offset_;
int width_;
int align_;
int bitFieldAlign_;
};
#endif

View File

@@ -0,0 +1,56 @@
#pragma once
#ifndef _WGTCC_VISITOR_H_
#define _WGTCC_VISITOR_H_
class BinaryOp;
class UnaryOp;
class TransOp;
class ConditionalOp;
class FuncCall;
class Identifier;
class Object;
class Enumerator;
class Constant;
class TempVar;
class Declaration;
class IfStmt;
class ForStmt;
class JumpStmt;
class ReturnStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
class Visitor {
public:
virtual ~Visitor() {}
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
virtual void VisitTransOp(TransOp* trans) = 0;
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
virtual void VisitEnumerator(Enumerator* enumer) = 0;
virtual void VisitIdentifier(Identifier* ident) = 0;
virtual void VisitObject(Object* obj) = 0;
virtual void VisitConstant(Constant* cons) = 0;
virtual void VisitTempVar(TempVar* tempVar) = 0;
virtual void VisitDeclaration(Declaration* init) = 0;
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) = 0;
virtual void VisitCompoundStmt(CompoundStmt* compStmt) = 0;
virtual void VisitFuncDef(FuncDef* funcDef) = 0;
virtual void VisitTranslationUnit(TranslationUnit* unit) = 0;
};
#endif

View File

@@ -0,0 +1,82 @@
#pragma once
#ifndef _TRITON_RUNTIME_ARG_H_
#define _TRITON_RUNTIME_ARG_H_
#include <string>
#include <stdexcept>
namespace triton{
namespace driver{
class buffer;
}
namespace runtime {
enum arg_type {
INT1_T,
INT8_T,
INT16_T,
INT32_T,
INT64_T,
HALF_T,
FLOAT_T,
DOUBLE_T,
BUFFER_T
};
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T: return 1;
case INT8_T: return 1;
case INT16_T: return 2;
case INT32_T: return 4;
case INT64_T: return 8;
case HALF_T: return 2;
case FLOAT_T: return 4;
case DOUBLE_T: return 8;
case BUFFER_T: return 8;
default: throw std::runtime_error("unknown type");
}
}
inline bool is_int_type(arg_type ty){
return ty == INT1_T || ty == INT8_T || ty == INT16_T ||
ty == INT32_T || ty == INT64_T;
}
class arg {
private:
union value_t {
bool int1;
int8_t int8;
int16_t int16;
int32_t int32;
int64_t int64;
float fp32;
double fp64;
driver::buffer* buf;
};
public:
// construct from primitive types
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; }
arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; }
// accessors
arg_type type() const { return ty_; }
void* data() const { return (void*)&val_; }
private:
arg_type ty_;
value_t val_;
};
}
}
#endif

View File

@@ -0,0 +1,122 @@
#pragma once
#ifndef _TRITON_RUNTIME_FUNCTION_H_
#define _TRITON_RUNTIME_FUNCTION_H_
#include <vector>
#include <string>
#include <memory>
#include <functional>
// codegen
#include "triton/ir/context.h"
#include "triton/codegen/target.h"
#include "triton/runtime/arg.h"
namespace llvm {
class Module;
class LLVMContext;
}
class Parser;
namespace triton {
namespace driver{
class module;
class stream;
class kernel;
class context;
class device;
}
namespace lang{
class translation_unit;
}
namespace codegen{
namespace analysis{
class tiles;
}
}
namespace ir {
class module;
class function;
class context;
}
namespace runtime{
typedef std::vector<size_t> grid_t;
typedef std::map<std::string, size_t> params_t;
template<typename T> inline T convert(const std::string& name);
template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
class function {
public:
struct options_space_t {
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<int> num_warps;
};
struct options_t {
template<class T>
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
std::map<std::string, std::string> defines;
size_t num_warps;
};
typedef std::function<grid_t(const options_t&)> grid_fn_ty;
private:
class caller {
public:
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt_);
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
const options_t opt() const { return opt_; }
private:
std::shared_ptr<driver::kernel> bin_;
std::shared_ptr<driver::module> parent_;
std::vector<arg_type> param_tys_;
options_t opt_;
};
private:
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
private:
triton::lang::translation_unit *make_ast(const std::string &src);
std::unique_ptr<ir::module> make_ir(Parser &parser);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
public:
static std::string preheader();
public:
function(const std::string& src, const options_space_t& opt = options_space_t());
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
void set_cst(const std::string& name, void* data, size_t n_bytes);
private:
ir::context ctx_;
std::string src_;
options_space_t opt_space_;
std::map<cache_key_t, caller> cache_;
std::map<std::string, std::vector<char>> cst_;
};
}
}
#endif

View File

@@ -0,0 +1,59 @@
#pragma once
#ifndef _TRITON_TOOLS_BENCH_H_
#define _TRITON_TOOLS_BENCH_H_
#include <chrono>
#include <functional>
#include <algorithm>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
namespace triton{
namespace tools{
class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
public:
explicit timer(bool run = false)
{ if (run) start(); }
void start()
{ _start = high_resolution_clock::now(); }
nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
private:
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
{
// const driver::device * device = stream->context()->device();
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
stream->synchronize();
while(total_time*1e-9 < 1e-2){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
if(normalize)
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return *std::min_element(times.begin(), times.end());
}
}
}
#endif

View File

@@ -0,0 +1,67 @@
#pragma once
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
#define _TRITON_TOOLS_THREAD_GRAPH_H_
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace tools{
template<class node_t>
class graph {
typedef std::map<node_t, std::set<node_t>> edges_t;
public:
typedef std::map<size_t, std::vector<node_t>> cmap_t;
typedef std::map<node_t, size_t> nmap_t;
private:
void connected_components_impl(node_t x, std::set<node_t> &nodes,
nmap_t* nmap, cmap_t* cmap, int id) const {
if(nmap)
(*nmap)[x] = id;
if(cmap)
(*cmap)[id].push_back(x);
if(nodes.find(x) != nodes.end()) {
nodes.erase(x);
for(const node_t &y: edges_.at(x))
connected_components_impl(y, nodes, nmap, cmap, id);
}
}
public:
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
if(cmap)
cmap->clear();
if(nmap)
nmap->clear();
std::set<node_t> nodes = nodes_;
unsigned id = 0;
while(!nodes.empty())
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
}
void add_edge(node_t x, node_t y) {
nodes_.insert(x);
nodes_.insert(y);
edges_[x].insert(y);
edges_[y].insert(x);
}
void clear() {
nodes_.clear();
edges_.clear();
}
private:
std::set<node_t> nodes_;
edges_t edges_;
};
}
}
#endif

View File

@@ -0,0 +1,56 @@
/*
* Copyright (c) 2015, PHILIPPE TILLET. All rights reserved.
*
* This file is part of ISAAC.
*
* ISAAC is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
*/
#ifndef TDL_TOOLS_SYS_GETENV_HPP
#define TDL_TOOLS_SYS_GETENV_HPP
#include <string>
#include <cstdlib>
namespace triton
{
namespace tools
{
inline std::string getenv(const char * name)
{
#ifdef _MSC_VER
char* cache_path = 0;
std::size_t sz = 0;
_dupenv_s(&cache_path, &sz, name);
#else
const char * cstr = std::getenv(name);
#endif
if(!cstr)
return "";
std::string result(cstr);
#ifdef _MSC_VER
free(cache_path);
#endif
return result;
}
}
}
#endif

View File

@@ -0,0 +1,68 @@
/*
* Copyright (c) 2015, PHILIPPE TILLET. All rights reserved.
*
* This file is part of ISAAC.
*
* ISAAC is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
*/
#ifndef TDL_TOOLS_SYS_MKDIR_HPP
#define TDL_TOOLS_SYS_MKDIR_HPP
#include <cstring>
#include <string>
#include <cstdlib>
#include <sys/stat.h>
#include <errno.h>
#if defined(_WIN32)
#include <direct.h>
#endif
namespace triton
{
namespace tools
{
inline int mkdir(std::string const & path)
{
#if defined(_WIN32)
return _mkdir(path.c_str());
#else
return ::mkdir(path.c_str(), 0777);
#endif
}
inline int mkpath(std::string const & path)
{
int status = 0;
size_t pp = 0;
size_t sp;
while ((sp = path.find('/', pp)) != std::string::npos)
{
if (sp != pp){
status = mkdir(path.substr(0, sp));
}
pp = sp + 1;
}
return (status==0 || errno==EEXIST)?0:-1;
}
}
}
#endif

View File

@@ -0,0 +1,100 @@
#pragma once
#ifndef _TRITON_TOOLS_THREAD_POOL_H_
#define _TRITON_TOOLS_THREAD_POOL_H_
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
}
#endif

View File

@@ -0,0 +1,514 @@
#include "triton/codegen/analysis/align.h"
#include "triton/ir/utils.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace analysis{
// Function for extended Euclidean Algorithm
int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
inline int lcm(int a, int b) {
return (a * b) / gcd(a, b);
}
template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
return map[i] = value;
}
/*
* is constant
*/
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_tile_ty())
return ty->get_tile_shapes();
else
return {1};
}
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = is_constant_.find(inc);
if(it != is_constant_.end())
result = it->second;
}
return add_to_cache(x, result, is_constant_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto cst = populate_is_constant(inc);
for(size_t d = 0; d < cst.size(); d++)
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
auto shapes = get_shapes(x);
ir::value* op = x->get_operand(0);
std::vector<cst_info> result;
auto op_cst = populate_is_constant(op);
for(auto d: shapes)
result.push_back(cst_info{d, op_cst[0].value});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < x_shapes.size(); d ++){
cst_info ax ;
if(x_shapes[d] == 1)
ax = {1, op_cst[current].value};
else if(!is_skewed
&& x_shapes[d] == op_shapes[current])
ax = {x_shapes[d], op_cst[current++].value};
else {
is_skewed = true;
ax = {x_shapes[d], 0};
}
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(cst_info{x_shapes[d], op_cst[d].value});
else
result.push_back(op_cst[d]);
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
auto x_shapes = get_shapes(x);
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
std::vector<cst_info> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
auto shapes = get_shapes(v);
std::vector<cst_info> result(shapes.size(), {1, 0});
return add_to_cache(v, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
}
/*
* max contiguous
*/
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = max_contiguous_.find(inc);
if(it != max_contiguous_.end())
result = it->second;
}
add_to_cache(x, result, max_contiguous_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto contiguous = populate_max_contiguous(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = std::min(result[d], contiguous[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<unsigned> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({1});
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result.push_back(1);
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result.push_back(op_mc[current++]);
else {
is_skewed = true;
result.push_back(1);
}
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(1);
else
result.push_back(op_mc[d]);
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result;
for(size_t d = 0; d < shapes.size(); d++){
unsigned value = 1;
if(x->is_int_rem() && rhs_cst_info[d].value > 0)
value = std::min(lhs_max_contiguous[d], rhs_cst_info[d].value);
if(x->is_int_mult()){
unsigned lvalue = 1, rvalue = 1;
if(rhs_cst_info[d].value == 1)
lvalue = lhs_max_contiguous[d];
if(lhs_cst_info[d].value == 1)
rvalue = rhs_max_contiguous[d];
value = std::max(lvalue, rvalue);
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst > 0)
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
if(rhs_cst_info[d].num_cst > 0)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result(shapes.size(), 1);
for(size_t d = 0; d < shapes.size(); d++){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = rhs_max_contiguous[d];
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_tile_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_max_contiguous_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_max_contiguous_gep(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_max_contiguous_phi(x);
return populate_max_contiguous_default(v);
}
/*
* starting multiple
*/
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
auto shapes = get_shapes(x);
auto op = populate_starting_multiple(x->get_operand(0));
std::vector<unsigned> result(shapes.size(), op[0]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
auto op = populate_starting_multiple(x->get_operand(0));
auto op_shapes = get_shapes(x->get_operand(0));
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result[d] = 1;
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result[d] = op[current++];
else {
is_skewed = true;
result[d] = 1;
}
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
if(x->is_int_mult())
result[d] = lhs[d] * rhs[d];
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
if(x->is_int_rem() && rhs[d] > 1)
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_shl())
result[d] = lhs[d] << rhs[d];
if(x->is_shr())
result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++)
result[d] = gcd(lhs[d], rhs[d]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
auto shape = get_shapes(x);
std::vector<unsigned> result(shape.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(starting_multiple_.find(inc) != starting_multiple_.end())
result = starting_multiple_.at(inc);
}
add_to_cache(x, result, starting_multiple_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto sm = populate_starting_multiple(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = gcd(result[d], sm[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of){
return add_to_cache(x, {attr.get_value()}, starting_multiple_);
}
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = nbits / 8;
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
}
}
}
return add_to_cache(v, {1}, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
return add_to_cache(x, {128}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_starting_multiple_phi(x);
return populate_starting_multiple_default(v);
}
unsigned align::get(ir::value *v, unsigned ax) const {
unsigned starting_multiple = starting_multiple_.at(v)[ax];
unsigned max_contiguous = max_contiguous_.at(v)[ax];
return std::min(starting_multiple, max_contiguous);
}
std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
}
}
}
}

View File

@@ -0,0 +1,107 @@
#include <algorithm>
#include <climits>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void allocation::run(ir::module &mod) {
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<shared_layout*> I;
for(auto x: liveness_->get())
I.push_back(x.first);
std::vector<shared_layout*> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<shared_layout*> V;
std::map<shared_layout*, unsigned> starts;
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
segment xj = liveness_->get(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
res = res && !val.second.intersect(xj);
return res;
});
if(j_it != J.end()){
unsigned size = (*j_it)->get_size();
segment xj = liveness_->get(*j_it);
starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
if(xh.start < xj.start)
H.insert({w, segment{xh.start, xj.end}});
if(xj.end < xh.end)
H.insert({w, segment{xj.start, xh.end}});
V.push_back(*j_it);
J.erase(j_it);
}
}
// Build interference graph
std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V)
for(shared_layout* y: V){
if(x == y)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x->get_size();
unsigned NY = y->get_size();
segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY};
if(liveness_->get(x).intersect(liveness_->get(y))
&& XS.intersect(YS))
interferences[x].insert(y);
}
// Initialize colors
std::map<shared_layout*, int> colors;
for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(shared_layout* x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(shared_layout* Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
}
// Assigns first available color
auto It = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), It);
}
// Finalize allocation
for(shared_layout* x: V){
unsigned Adj = 0;
for(shared_layout* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
}
}
}
}

View File

@@ -0,0 +1,147 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
namespace triton{
namespace codegen{
namespace analysis{
axes::axes() {}
void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
ir::value *arg = red->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
unsigned current = 0;
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
continue;
graph_.add_edge({i, current++}, {arg, d});
}
}
void axes::update_graph_reshape(ir::instruction *i) {
auto* reshape = static_cast<ir::reshape_inst*>(i);
// operands
ir::value *op = reshape->get_operand(0);
// shapes
auto op_shapes = op->get_type()->get_tile_shapes();
auto res_shapes = reshape->get_type()->get_tile_shapes();
// construct edges
unsigned current = 0;
bool is_skewed = false;
for(unsigned d = 0; d < res_shapes.size(); d ++){
bool same_shape = res_shapes[d] == op_shapes[current];
// either add edge between axis or just add a node in the graph
if(!is_skewed && same_shape)
graph_.add_edge({i, d}, {op, current++});
else
graph_.add_edge({i, d}, {i, d});
// reshaping is skewed
if(res_shapes[d] > 1 && !same_shape)
is_skewed = true;
}
}
void axes::update_graph_trans(ir::instruction *i) {
auto *trans = static_cast<ir::trans_inst*>(i);
ir::value *op = trans->get_operand(0);
auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++)
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_tile_shapes();
ir::value *op = broadcast->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
graph_.add_edge({i, d}, {op, d});
}
void axes::update_graph_dot(ir::instruction *i) {
auto *dot = static_cast<ir::dot_inst*>(i);
auto shapes = dot->get_type()->get_tile_shapes();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
// add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++)
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_tile_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
}
void axes::update_graph_no_edge(ir::instruction *i) {
if(!i->get_type()->is_tile_ty())
return;
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
graph_.add_edge({i, d}, {i, d});
}
void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);;
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;
}
int axes::get(ir::value *value, unsigned dim) {
return axes_.at({value, dim});
}
std::vector<int> axes::get(ir::value *value) {
std::vector<int> result;
for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
result.push_back(this->get(value, d));
return result;
}
void axes::run(ir::module &mod) {
// make graph
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x);
});
// find connected components
graph_.connected_components(nullptr, &axes_);
}
}
}
}

View File

@@ -0,0 +1,442 @@
#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
inline bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_half_ty() &&
b_ty->get_scalar_ty()->is_half_ty();
}
return result;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(v);
}
}
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
result = v;
}
}
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
result = v;
}
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
/* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this);
}
/* -------------------------------- *
* Base Data Layout *
* -------------------------------- */
data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values_)
extract_io_use(v, ptr);
order_.resize(axes_.size());
std::iota(order_.begin(), order_.end(), 0);
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
});
if(*largest){
auto max_contiguous = align->contiguous(*largest);
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
}
}
size_t data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
return std::distance(axes_.begin(), it);
}
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
mma884_layout::mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
wpt_ = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
}while(wpt_nm1 != wpt_);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
if(num_warps != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32;
nts_.resize(shape_.size());
mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order_[0];
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
if(is_dot)
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shape_.size(); d++){
i = order_[d];
if(d > 1 || !is_dot)
nts_[i] = 1;
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
if(num_warps * 32 != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
ir::basic_block *block_0 = phi->get_incoming_block(0);
ir::basic_block *block_1 = phi->get_incoming_block(1);
ir::instruction *terminator_0 = block_0->get_inst_list().back();
ir::instruction *terminator_1 = block_1->get_inst_list().back();
bool is_latch_0 = is_loop_latch(phi, terminator_0);
bool is_latch_1 = is_loop_latch(phi, terminator_1);
ir::value *value_0 = phi->get_incoming_value(0);
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
if(is_latch_0)
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
shared_layout::shared_layout(const data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size_ = 0;
// double-buffering
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer_);
// order
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order_ = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
ir::value* hmma_dot_b = nullptr;
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order_ != arg_order) {
pad = 4;
}
shape_[order_[0]] += pad;
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
size_ *= s;
if(double_buffer_)
size_ *= 2;
}
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() <
y->get_type()->get_tile_ranks1();
};
std::vector<ir::value*> lvalue = values;
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v);
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
}
else
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components
graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_)
create(x.first, x.second);
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_tile_shapes();
unsigned shape_ax = shapes[axis];
scanline_layout *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts(axis);
unsigned depth = shape_ax / per_thread;
shapes[axis] = depth;
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id;
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
mma884_layout* in_layout = get(val)->to_mma884();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
// create layout
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id;
}
});
}
}
}
}

View File

@@ -0,0 +1,57 @@
#include <climits>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
}
}
// create live intervals
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
intervals_[layout] = segment{start, end};
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,326 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"
namespace triton{
namespace codegen{
using namespace llvm;
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder.getInt32(shapes[order[k]]);
Value *rem = builder.CreateURem(trailing, dim_k);
trailing = builder.CreateUDiv(trailing, dim_k);
result[order[k]] = rem;
}
result[order[dim - 1]] = trailing;
return result;
}
inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::shared_layout *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->get_double_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
ptr_ = builder_->CreatePHI(ptr_ty, 2);
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
builder_->SetInsertPoint(current);
}
else{
size_t offset = alloc_->offset(layout_);
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
}
}
tile* machine_shared_layout::create(ir::value *v) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
auto double_buffer = layout_->get_double_buffer();
// offset
Value *offset = nullptr;
if(double_buffer && v == double_buffer->phi)
offset = offset_;
// base pointer
Value *ptr = ptr_;
if(double_buffer && v == double_buffer->latch)
ptr = next_ptr_;
else if(double_buffer && v == double_buffer->first)
ptr = pre_ptr_;
// create tile
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout *layout)
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
tile *machine_distributed_layout::create(ir::value *v) {
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> order(rank);
// compute axes
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
// compute order
std::iota(order.begin(), order.end(), 0);
auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y);
size_t posx = layout_->find_axis(axx);
size_t posy = layout_->find_axis(axy);
if(posx < rank && posy < rank)
return layout_->get_order(posx) < layout_->get_order(posy);
return false;
};
std::sort(order.begin(), order.end(), cmp);
return new distributed_tile(ty, shapes, order, axes, *builder_);
}
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shape = layout->get_shape();
if(shape.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shape.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
// mma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// mma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shape[0] / hmma_bts_0;
unsigned num_rep_1 = shape[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::scanline_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->get_order();
const auto& shape = layout->get_shape();
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
// Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_id(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
thread_id[order[k]] = rem;
}
thread_id[order[dim - 1]] = full_thread_id;
// Create axes
for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts * mts;
unsigned per_thread = nts * shape[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts * per_block + n % nts;
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
}
}
}
}

View File

@@ -0,0 +1,214 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"
namespace triton{
namespace codegen{
using namespace llvm;
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
// build
size_t k = 0;
while(true) {
indices_t current;
for(size_t d = 0; d < id.size(); d++)
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[order_[0]]++;
while(id[order_[k]] == axes_[order_[k]].values.size()){
if(k == id.size() - 1)
return;
id[order_[k++]] = 0;
id[order_[k]]++;
}
k = 0;
}
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
if(end < 0)
end = ordered_indices_.size() + end + 1;
for(unsigned i = start; i < end; i++)
fn(ordered_indices_[i]);
}
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
int rank = sizes.size();
int len = 1;
for(int s: sizes)
len *= s;
for(int i = 0; i < len; i++){
indices_t idx(rank);
int current = i;
for(int k = 0; k < rank; k++){
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
current = current / sizes[k];
}
fn(idx);
}
}
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
if(dyn_cast<Constant>(arg)){
cst = arg;
non_cst = _0;
return;
}
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
non_cst = arg;
cst = _0;
return;
}
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
if(cst_lhs && cst_rhs){
cst = arg;
non_cst = _0;
}
else if(cst_lhs){
cst = cst_lhs;
non_cst = bin_op->getOperand(1);
}
else if(cst_rhs){
cst = cst_rhs;
non_cst = bin_op->getOperand(0);
}
else{
non_cst = arg;
cst = _0;
}
}
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
non_cst_idx.clear();
cst_idx.clear();
for(Value *idx: arg_idx){
Value *non_cst, *cst;
extract_constant(idx, non_cst, cst);
non_cst_idx.push_back(non_cst);
cst_idx.push_back(cst);
}
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
const std::vector<int>& perm, const std::vector<int>& order,
indices_t idx) {
// strides
std::vector<Value*> strides(order.size());
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0);
for(size_t i = 0; i < strides.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
}
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
unsigned vector_size = vector_size_;
Type *ty = ty_;
if(ty->isHalfTy() && (vector_size % 2 == 0)){
ty = IntegerType::get(ty->getContext(), 32);
vector_size = vector_size / 2;
}
if(base_ptr == nullptr){
// BasicBlock* store = builder_.GetInsertBlock();
// if(!non_cst_idx.empty())
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
}
}
}

174
lib/codegen/target.cc Normal file
View File

@@ -0,0 +1,174 @@
#include "triton/codegen/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>
using namespace llvm;
namespace triton{
namespace codegen{
// base
bool target::is_gpu() const {
return is_gpu_;
}
// AMD
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
}
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateCall(barrier, {});
}
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
return group_id;
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::r600_read_ngroups_x,
Intrinsic::r600_read_ngroups_y,
Intrinsic::r600_read_ngroups_z
};
Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_num_group, {});
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
Intrinsic::amdgcn_workitem_id_y,
Intrinsic::amdgcn_workitem_id_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
// NVIDIA
void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
// set metadata
Metadata *md_args[] = {
ValueAsMetadata::get(fn),
MDString::get(ctx, "kernel"),
ValueAsMetadata::get(builder.getInt32(1))
};
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
}
Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
return builder.CreateCall(barrier, {});
}
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
return builder.CreateCall(barrier, {});
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_tid_x,
Intrinsic::nvvm_read_ptx_sreg_tid_y,
Intrinsic::nvvm_read_ptx_sreg_tid_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_nctaid, {});
}
// CPU
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
// normal cpu functions can be kernels
}
Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
static std::array<const Argument*, 3> ids = {
fn->arg_begin() + num_params - 3,
fn->arg_begin() + num_params - 2,
fn->arg_begin() + num_params - 1
};
return (Argument*)ids[ax];
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;
}
Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
return builder.getInt32(0);
}
}
}

View File

@@ -0,0 +1,143 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
ir::value *ptr = i->get_pointer_operand();
auto contiguous = align_->contiguous(ptr);
auto it = std::max_element(contiguous.begin(), contiguous.end());
int axis = std::distance(contiguous.begin(), it);
result[axis].push_back(i);
}
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
std::map<ir::value*, ir::value*>& seen) {
if(seen.find(x) != seen.end())
return seen.at(x);
auto i = dynamic_cast<ir::instruction*>(x);
// not an instruction -- forward value
if(!i)
return x;
// already in shared memory -- forward value
if(dynamic_cast<ir::copy_to_shared_inst*>(x)){
return x;
}
// set insert point
auto& inst_list = i->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), i);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(x)){
ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x));
return ret;
}
// default -- recursive clone
ir::instruction *cloned = builder.insert(i->clone());
seen[i] = cloned;
// rematerialize operands
for(ir::value *op: cloned->ops())
cloned->replace_uses_of_with(op, rematerialize(op, builder, seen));
return cloned;
}
void coalesce::run(ir::module &mod) {
size_t num_groups = layout_->num_layouts();
for(size_t id = 0; id < num_groups; id++) {
if(!layout_->get(id)->to_mma884())
continue;
// extract memory stores
const auto& values = layout_->values_of(id);
ir::value* dot = nullptr;
for(ir::value *v: values)
if(auto x = dynamic_cast<ir::dot_inst*>(v))
dot = x;
ir::builder& builder = mod.get_builder();
std::vector<ir::value*> worklist = {dot};
std::set<ir::value*> seen;
while(!worklist.empty()) {
ir::value *current = worklist.back();
seen.insert(current);
worklist.pop_back();
// stop if trunc
if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
builder.set_insert_point_after(x);
ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
builder.insert(rc);
x->replace_all_uses_with(rc);
rc->replace_uses_of_with(rc, x);
break;
}
// recurse
for(ir::user *u: current->get_users())
if(seen.find(u) == seen.end())
worklist.push_back(u);
}
}
// find values to rematerialize
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values_of(id);
// extract pointers used in ld/st operations
std::set<ir::io_inst*> io;
for(ir::value *v: values)
extract_io_use(v, io);
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
extract_ld(i, axes);
}
// update list of values to rematerialize
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
remat.insert(remat.begin(), it->second.begin(), it->second.end());
}
// rematerialize values
for(ir::io_inst *r: remat) {
ir::builder& builder = mod.get_builder();
// rematerialize operands
std::map<ir::value*, ir::value*> seen;
for(ir::value *op: r->ops())
r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen));
// copy to shared if load
auto& inst_list = r->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(r)){
ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
r->replace_all_uses_with(cts);
cts->replace_uses_of_with(cts, r);
}
}
}
}
}
}

View File

@@ -0,0 +1,95 @@
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
return op==0;
return false;
}
inline bool is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
if(i->get_id() == ir::INST_TRANS)
return true;
if(i->get_id() == ir::INST_REDUCE)
return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true;
return false;
}
// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
builder.set_insert_point(parent);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
return;
}
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
return;
}
// already in shared memory
if(to_shared && is_shmem_res(i))
return;
// copy
builder.set_insert_point_after(i);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k))
add_copy(i, i->get_operand(k), builder, true);
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
}
}
}
}
}

Some files were not shown because too many files have changed in this diff Show More